Compare commits

...

7 Commits

Author SHA1 Message Date
Viktor Liu
91ae600670 Unify peer and route ACL filtering with multi-source peer rules 2026-06-01 10:49:47 +02:00
Pascal Fischer
9189625487 [management] enrich context in permissions manager (#6286) 2026-05-29 16:36:38 +02:00
Bethuel Mmbaga
e9dbf9db6f [management] Extend combined server initialization (#6156) 2026-05-29 17:35:35 +03:00
Theodor Midtlien
5a9e9e7bc9 [Infrastructure] Pin actions with SHA and improve workflows (#6249)
* Pin actions with SHA, replace unmaintained, add dependabot for actions

* Update FreeBSD to version 15 for tests

* Use shared actions

* Update sign-pipelines version
2026-05-29 15:24:30 +02:00
Viktor Liu
43e041cf9f [client] Apply netroute unspecified-destination workaround on android (#6192) 2026-05-29 15:15:22 +02:00
Viktor Liu
77e5693200 [client] Recognize NetBird DNS forwarder port in capture text format (#6177) 2026-05-29 15:14:32 +02:00
Zoltan Papp
174dc24867 [management] Add SSO session extend flow (management) (#6197)
* add SSO session extend flow (management)

Adds the management-server half of the SSO session-extension feature:

- New ExtendAuthSession gRPC RPC that refreshes a peer's session expiry
  using a fresh JWT, validated through the same pipeline as Login but
  without tearing down the tunnel or redoing the NetworkMap sync.
- Per-peer SessionExpiresAt timestamp on every LoginResponse and
  SyncResponse so connected clients learn the deadline on the existing
  long-lived stream, and admin-side changes (toggling expiration,
  changing the expiration window) reach every peer within seconds.
- SessionExpiresAt(...) helper on Peer that derives the absolute UTC
  deadline from LastLogin + the account-level PeerLoginExpiration
  setting, returning zero when the peer is not SSO-tracked or expiration
  is disabled.

The matching client-side consumer of these fields lands separately.

* encode SessionExpiresAt as 3-state on the wire

Previously the `sessionExpiresAt` field on LoginResponse, SyncResponse
and ExtendAuthSessionResponse was 2-state: a valid timestamp meant
"new deadline", and nil meant "clear". That conflated two distinct
meanings — "no info in this snapshot" vs "expiry is explicitly off /
peer is not SSO-tracked" — so a Sync push that legitimately couldn't
compute the deadline (settings lookup failed) would silently clear the
client's anchor and lose the warning window.

Three states now, encoded on the same field number (no .proto schema
churn — only comments and the server-side encoder change):

  - nil pointer (field absent) → "no info"; client preserves anchor
  - &Timestamp{} (seconds=0, nanos=0) → explicit "disabled / not SSO"
    sentinel; client clears
  - valid timestamp → new absolute UTC deadline

A new encodeSessionExpiresAt helper centralises the zero/non-zero
encoding and is shared by the Sync, Login and ExtendAuthSession
builders. The Sync builder still emits nil when settings are missing.
Login and ExtendAuthSession always carry an authoritative value.

The matching client-side decoder lands on feature/session-extend.

* add UserExtendedPeerSession activity event

ExtendAuthSession previously reused UserLoggedInPeer for its audit
record, which conflated two distinct user actions: a full interactive
SSO login (tunnel re-established, network map resync) versus an
in-place deadline refresh (tunnel untouched). Auditors reading the log
couldn't tell which one happened, and downstream dashboards/alerts on
"login" volume were polluted by routine extends.

Adds a dedicated UserExtendedPeerSession Activity (code 125,
"user.peer.session.extend") and switches ExtendPeerSession over to it.
The peer-extend audit trail is now distinguishable from interactive
logins.

* make ExtendAuthSession JWT-retry backoff cancellable

Skip the retry log and 200ms wait on the final attempt, and replace the
uncancellable time.Sleep with a select on time.After/ctx.Done so an
upstream cancellation aborts the wait instead of running it to
completion.
2026-05-28 19:14:14 +02:00
130 changed files with 9077 additions and 7324 deletions

45
.github/dependabot.yml vendored Normal file
View File

@@ -0,0 +1,45 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
actions:
patterns:
- "*"
ignore:
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
- dependency-name: "git-town/action"
update-types:
- "version-update:semver-minor"
- "version-update:semver-major"
- package-ecosystem: "gomod"
directories:
- "/"
schedule:
interval: "daily"
open-pull-requests-limit: 15
groups:
aws-sdk:
patterns:
- "github.com/aws/aws-sdk-go-v2/*"
pion:
patterns:
- "github.com/pion/*"
gorm:
patterns:
- "gorm.io/*"
otel:
patterns:
- "go.opentelemetry.io/*"
testcontainers:
patterns:
- "github.com/testcontainers/testcontainers-go/*"
wireguard:
patterns:
- "golang.zx2c4.com/wireguard*"

View File

@@ -2,16 +2,16 @@ name: Check License Dependencies
on:
push:
branches: [ main ]
branches: [main]
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.github/workflows/check-license-dependencies.yml'
- "go.mod"
- "go.sum"
- ".github/workflows/check-license-dependencies.yml"
jobs:
check-internal-dependencies:
@@ -19,7 +19,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check for problematic license dependencies
run: |
@@ -56,55 +59,57 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: true
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Install go-licenses
run: go install github.com/google/go-licenses@v1.6.0
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
- name: Check for GPL/AGPL licensed dependencies
run: |
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
echo ""
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check all Go packages for copyleft licenses, excluding internal netbird packages
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
if [ -n "$COPYLEFT_DEPS" ]; then
echo "Found copyleft licensed dependencies:"
echo "$COPYLEFT_DEPS"
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
# Filter out dependencies that are only pulled in by internal AGPL packages
INCOMPATIBLE=""
while IFS=',' read -r package url license; do
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
# Find ALL packages that import this GPL package using go list
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
# Check if any importer is NOT in management/signal/relay
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
if [ -n "$BSD_IMPORTER" ]; then
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
else
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
fi
fi
done <<< "$COPYLEFT_DEPS"
if [ -n "$INCOMPATIBLE" ]; then
echo ""
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
echo -e "$INCOMPATIBLE"
exit 1
fi
fi
echo "✅ All external license dependencies are compatible with BSD-3-Clause"

View File

@@ -83,7 +83,7 @@ jobs:
- name: Verify docs PR exists (and is open or merged)
if: steps.validate.outputs.mode == 'added'
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
id: verify
with:
pr_number: ${{ steps.extract.outputs.pr_number }}

View File

@@ -8,11 +8,10 @@ jobs:
post:
runs-on: ubuntu-latest
steps:
- uses: roots/discourse-topic-github-release-action@main
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
with:
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
discourse-base-url: https://forum.netbird.io
discourse-author-username: NetBird
discourse-category: 17
discourse-tags:
releases
discourse-tags: releases

View File

@@ -3,7 +3,7 @@ name: Git Town
on:
pull_request:
branches:
- '**'
- "**"
jobs:
git-town:
@@ -15,7 +15,9 @@ jobs:
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1.2.1
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
with:
skip-single-stacks: true

View File

@@ -16,16 +16,18 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -44,4 +46,3 @@ jobs:
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)

View File

@@ -15,20 +15,31 @@ jobs:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with:
usesh: true
copyback: false
release: "14.2"
release: "15.0"
envs: "GO_VERSION"
prepare: |
pkg install -y curl pkgconf xorg
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to faile on first error

View File

@@ -18,9 +18,11 @@ jobs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- uses: dorny/paths-filter@v3
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
id: filter
with:
filters: |
@@ -28,7 +30,7 @@ jobs:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -36,10 +38,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache
with:
path: |
@@ -113,14 +115,16 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: ["386", "amd64"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -128,10 +132,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -154,18 +158,20 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -177,7 +183,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
id: cache-restore
with:
path: |
@@ -214,7 +220,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
'
test_relay:
@@ -231,10 +237,12 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -246,10 +254,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -277,14 +285,16 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: ["386", "amd64"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -298,7 +308,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -324,14 +334,16 @@ jobs:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
arch: ["386", "amd64"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -343,10 +355,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -370,19 +382,21 @@ jobs:
test_management:
name: "Management / Unit"
needs: [ build-cache ]
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
arch: ["amd64"]
store: ["sqlite", "postgres", "mysql"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -390,10 +404,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -410,7 +424,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -427,7 +441,7 @@ jobs:
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
@@ -437,13 +451,13 @@ jobs:
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
needs: [build-cache]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: ["amd64"]
store: ["sqlite", "postgres"]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -474,10 +488,12 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -485,10 +501,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -505,7 +521,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -529,13 +545,13 @@ jobs:
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
needs: [build-cache]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
arch: ["amd64"]
store: ["sqlite", "postgres"]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
@@ -566,10 +582,12 @@ jobs:
prom/prometheus
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -577,10 +595,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -597,7 +615,7 @@ jobs:
- name: Login to Docker hub
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
@@ -623,20 +641,22 @@ jobs:
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
needs: [build-cache]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
arch: ["amd64"]
store: ["sqlite", "postgres"]
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -644,10 +664,10 @@ jobs:
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}

View File

@@ -18,10 +18,12 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
id: go
with:
go-version-file: "go.mod"
@@ -33,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
${{ env.cache }}
@@ -44,16 +46,15 @@ jobs:
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'

View File

@@ -15,9 +15,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: codespell
uses: codespell-project/actions-codespell@v2
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
skip: go.mod,go.sum,**/proxy/web/**
@@ -38,13 +40,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
@@ -52,7 +56,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
with:
version: latest
skip-cache: true

View File

@@ -22,7 +22,9 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: run install script
env:

View File

@@ -16,23 +16,25 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -52,9 +54,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: install gomobile

View File

@@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Validate PR title prefix
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const title = context.payload.pull_request.title;

View File

@@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.4"
SIGN_PIPE_VER: "v0.1.5"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -24,7 +24,9 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Generate FreeBSD port diff
run: bash release_files/freebsd-port-diff.sh
@@ -51,19 +53,26 @@ jobs:
echo "Generated files for version: $VERSION"
cat netbird-*.diff
- name: Read Go version from go.mod
id: goversion
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
- name: Test FreeBSD port
if: steps.check_diff.outputs.diff_exists == 'true'
uses: vmactions/freebsd-vm@v1
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
with:
usesh: true
copyback: false
release: "15.0"
envs: "GO_VERSION"
prepare: |
# Install required packages
pkg install -y git curl portlint go
pkg install -y git curl portlint
# Install Go for building
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -LO "$GO_URL"
tar -C /usr/local -xzf "$GO_TARBALL"
@@ -93,19 +102,19 @@ jobs:
# Show patched Makefile
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
cd /usr/ports/security/netbird
export BATCH=yes
make package
pkg add ./work/pkg/netbird-*.pkg
netbird version | grep "$version"
echo "FreeBSD port test completed successfully!"
- name: Upload FreeBSD port files
if: steps.check_diff.outputs.diff_exists == 'true'
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: freebsd-port-files
path: |
@@ -124,26 +133,25 @@ jobs:
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -156,18 +164,18 @@ jobs:
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
- name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
uses: docker/login-action@v3
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
with:
registry: ghcr.io
username: ${{ github.actor }}
@@ -191,7 +199,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
@@ -282,28 +290,28 @@ jobs:
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: release
path: dist/
retention-days: 7
- name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
- name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
- name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -314,27 +322,26 @@ jobs:
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -375,7 +382,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -404,7 +411,7 @@ jobs:
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: release-ui
path: dist/
@@ -418,16 +425,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: |
~/go/pkg/mod
@@ -441,7 +449,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -449,7 +457,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: release-ui-darwin
path: dist/
@@ -474,27 +482,26 @@ jobs:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@v4
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@v4
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
with:
name: release-ui
path: release-ui
@@ -514,29 +521,27 @@ jobs:
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
destination: ${{ env.downloadPath }}\wintun.zip
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
- name: Decompress wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
destination: ${{ env.downloadPath }}\mesa3d.7z
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
@@ -547,35 +552,38 @@ jobs:
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
destination: ${{ github.workspace }}\envar_plugin.zip
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
shell: pwsh
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
run: |
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
Copy-Item -Destination $nsisPluginDir -Force
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
@@ -592,7 +600,7 @@ jobs:
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
@@ -611,7 +619,7 @@ jobs:
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
@@ -703,7 +711,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: Sign bin and installer
repo: netbirdio/sign-pipelines

View File

@@ -14,9 +14,9 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'
inputs: '{ "sha": "${{ github.sha }}" }'

View File

@@ -3,7 +3,7 @@ name: sync tag
on:
push:
tags:
- 'v*'
- "v*"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: sync-tag.yml
ref: main
@@ -29,7 +29,7 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: bump-netbird.yml
ref: main
@@ -42,10 +42,10 @@ jobs:
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -6,10 +6,10 @@ on:
- main
pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
- "infrastructure_files/**"
- ".github/workflows/test-infrastructure-files.yml"
- "management/cmd/**"
- "signal/cmd/**"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres', 'mysql' ]
store: ["sqlite", "postgres", "mysql"]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
@@ -68,15 +68,17 @@ jobs:
run: sudo apt-get install -y curl
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@v4
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -139,8 +141,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
@@ -254,7 +256,9 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh

View File

@@ -3,9 +3,9 @@ name: update docs
on:
push:
tags:
- 'v*'
- "v*"
paths:
- 'shared/management/http/api/openapi.yml'
- "shared/management/http/api/openapi.yml"
jobs:
trigger_docs_api_update:
@@ -13,10 +13,10 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@v1
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -19,15 +19,17 @@ jobs:
GOARCH: wasm
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
with:
version: latest
install-mode: binary
@@ -42,9 +44,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@v5
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
go-version-file: "go.mod"
- name: Build Wasm client
@@ -65,4 +69,3 @@ jobs:
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

View File

@@ -1,554 +0,0 @@
package iptables
import (
"errors"
"fmt"
"net"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
v6 bool
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
return &aclManager{
iptablesClient: iptablesClient,
wgIface: wgIface,
entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
}, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
chain := chainNameInputRules
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
if m.v6 && ipsetName != "" {
ipsetName += "-v6"
}
proto := protoForFamily(protocol, m.v6)
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
// if ruleset already exists it means we already have the firewall rule
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
ipList.addIP(ip.String())
return []firewall.Rule{&Rule{
ruleID: uuid.New().String(),
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
specs: specs,
v6: m.v6,
}}, nil
}
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
}
}
if err := m.createIPSet(ipsetName); err != nil {
return nil, fmt.Errorf("create ipset: %w", err)
}
if err := m.addToIPSet(ipsetName, ip); err != nil {
return nil, fmt.Errorf("add IP to ipset: %w", err)
}
ipList := newIpList(ip.String())
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
if action == firewall.ActionDrop {
// Insert at the beginning of the chain (position 1)
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
} else {
err = m.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
v6: m.v6,
}
m.updateState()
return []firewall.Rule{rule}, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
shouldDestroyIpset := false
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
ip := net.ParseIP(r.ip)
if ip == nil {
return fmt.Errorf("parse IP %s", r.ip)
}
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
return fmt.Errorf("delete ip from ipset: %w", err)
}
delete(ipsetList.ips, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ipsetList.ips) != 0 {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.ipsetStore.deleteIpset(r.ipsetName)
shouldDestroyIpset = true
}
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if shouldDestroyIpset {
if err := m.destroyIPSet(r.ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy empty ipset: %v", err)
} else {
log.Errorf("destroy empty ipset: %v", err)
}
}
}
m.updateState()
return nil
}
func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
}
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range m.entries["INPUT"] {
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range m.entries["FORWARD"] {
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range m.entries["PREROUTING"] {
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
}
}
if err := m.destroyIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
} else {
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
}
}
m.ipsetStore.deleteIpset(ipsetName)
}
return nil
}
func (m *aclManager) createDefaultChains() error {
// chain netbird-acl-input-rules
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec)
}
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if m.v6 {
currentState.ACLEntries6 = m.entries
currentState.ACLIPsetStore6 = m.ipsetStore
} else {
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
}
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
// don't use IP matching if IP is 0.0.0.0
matchByIP := !ip.IsUnspecified()
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
if ipsetName == "" {
return ""
}
actionSuffix := ""
if action == firewall.ActionDrop {
actionSuffix = "-drop"
}
switch {
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport" + actionSuffix
case sPort != nil:
return ipsetName + "-sport" + actionSuffix
case dPort != nil:
return ipsetName + "-dport" + actionSuffix
default:
return ipsetName + actionSuffix
}
}
func (m *aclManager) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if m.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add IP to ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
cidr := uint8(32)
if ip.To4() == nil {
cidr = 128
}
entry := &ipset.Entry{
IP: ip,
CIDR: cidr,
}
if err := ipset.Del(name, entry); err != nil {
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
}
return nil
}
func (m *aclManager) flushIPSet(name string) error {
return ipset.Flush(name)
}
func (m *aclManager) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -0,0 +1,346 @@
//go:build !android
package iptables
import (
"fmt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) cleanUpDefaultForwardRules() error {
if err := r.cleanJumpRules(); err != nil {
return fmt.Errorf("clean jump rules: %w", err)
}
log.Debug("flushing routing related tables")
// Remove jump rules from built-in chains before deleting custom chains,
// otherwise the chain deletion fails with "device or resource busy".
if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
} else if ok {
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Delete(tableNat, chainOUTPUT, jumpRule...); err != nil {
log.Debugf("clean OUTPUT jump rule: %v", err)
}
}
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainNATOutput, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} else if ok {
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
}
return nil
}
func (r *family) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
// Fallback: clear chains that survived an unclean shutdown.
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
func (r *family) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) addJumpRules() error {
// Jump to nat chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %w", err)
}
r.rules[jumpNatPost] = natRule
// Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
}
r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %w", err)
}
r.rules[jumpNatPre] = rdrRule
return nil
}
func (r *family) cleanJumpRules() error {
for _, ruleKey := range []firewall.RuleID{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
if rule, exists := r.rules[ruleKey]; exists {
var table, chain string
switch ruleKey {
case jumpNatPost:
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table = tableMangle
chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
case jumpMSSClamp:
table = tableMangle
chain = chainFORWARD
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s: %w", chain, table, err)
}
delete(r.rules, ruleKey)
}
}
return nil
}
func (r *family) cleanAclChains() error {
ok, err := r.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
for _, rule := range r.entries[chainINPUT] {
err := r.iptablesClient.DeleteIfExists(tableName, chainINPUT, rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
for _, rule := range r.entries[chainFORWARD] {
err := r.iptablesClient.DeleteIfExists(tableName, chainFORWARD, rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = r.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
return err
}
}
ok, err = r.iptablesClient.ChainExists(tableMangle, chainPREROUTING)
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
if ok {
for _, rule := range r.entries[chainPREROUTING] {
err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
}
for _, rule := range r.entries[mangleFwdKey] {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
return nil
}
func (r *family) createDefaultChains() error {
if err := r.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
return err
}
for chain, rules := range r.entries {
// mangle FORWARD guard rules are handled separately below
if chain == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
}
}
for chain, entries := range r.optionalEntries {
for _, entry := range entries {
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
r.entries[chain] = append(r.entries[chain], entry.spec)
}
}
clear(r.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range r.entries[mangleFwdKey] {
if err := r.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
// seedInitialEntries adds default rules to the entries map. Rules are
// inserted at position 1, so the order here is reversed.
//
// Existing FORWARD policy decides outbound traffic towards our
// interface. If FORWARD policy is "drop", we add an
// established/related rule to allow return traffic for inbound rules.
func (r *family) seedInitialEntries() {
established := getConntrackEstablished()
r.appendToEntries(chainINPUT, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainINPUT, []string{"-i", r.wgIface.Name(), "-j", chainNameInputRules})
r.appendToEntries(chainINPUT, append([]string{"-i", r.wgIface.Name()}, established...))
r.appendToEntries(chainFORWARD, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
r.appendToEntries(chainFORWARD, []string{"-o", r.wgIface.Name(), "-j", chainRTFWDOUT})
r.appendToEntries(chainFORWARD, []string{"-i", r.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from
// the wg interface, it traverses FORWARD instead of INPUT,
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
// inserted above ours. Mangle runs before filter, so these guard
// rules enforce the ACL mark check where it cannot be overridden.
r.appendToEntries(mangleFwdKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
r.appendToEntries(mangleFwdKey, []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (r *family) seedInitialOptionalEntries() {
r.optionalEntries[chainFORWARD] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
}
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
r.entries[chain] = append(r.entries[chain], spec)
}

View File

@@ -0,0 +1,269 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[firewall.RuleID]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
dnatRule := []string{
"-i", r.wgIface.Name(),
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
info := ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = info.rule
r.updateState()
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
return fmt.Errorf("delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.rules[jumpNatOutput]; exists {
return nil
}
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
if err != nil {
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
}
if !chainExists {
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
}
}
jumpRule := []string{"-j", chainNATOutput}
if err := r.iptablesClient.Insert(tableNat, chainOUTPUT, 1, jumpRule...); err != nil {
if !chainExists {
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
}
}
return fmt.Errorf("add OUTPUT jump rule: %w", err)
}
r.rules[jumpNatOutput] = jumpRule
r.updateState()
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
dnatRule := []string{
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
"--dport", strconv.Itoa(int(originalPort)),
"-d", localAddr.String(),
"-j", "DNAT",
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
}
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
r.updateState()
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if dnatRule, exists := r.rules[ruleID]; exists {
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
return fmt.Errorf("delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
}
r.updateState()
return nil
}

View File

@@ -0,0 +1,248 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableName = tableFilter
tableNat = "nat"
tableMangle = "mangle"
// chainNameInputRules is the peer ACL chain that holds installed
// peer-filtering rules.
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard
// rules that prevent external DNAT from bypassing ACL rules.
mangleFwdKey chainKey = "MANGLE-FORWARD"
chainINPUT = "INPUT"
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainFORWARD = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainNATOutput = "NETBIRD-NAT-OUTPUT"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpNatOutput = "jump-nat-output"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
fwdSuffix firewall.RuleID = "_fwd"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
)
type ruleInfo struct {
chain string
table string
rule []string
}
type routeRules map[firewall.RuleID][]string
// ruleSpec is a single iptables rule expressed as its argument list
// (e.g. {"-i", "wg0", "-j", "DROP"}).
type ruleSpec []string
// chainKey identifies the chain a seeded entry belongs to. It holds
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
// synthetic mangleFwdKey bucket for the mangle FORWARD guard rules.
type chainKey string
// aclEntries maps a chain to the rules seeded into it to jump into or
// guard the netbird ACL chains.
type aclEntries map[chainKey][]ruleSpec
type entry struct {
spec ruleSpec
position int
}
// ipsetCounter is the shared hash:net refcounter used by peer and
// route ACLs alike. The ipset library does not support comments, so
// the key is just the set name (string).
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
// family holds the per-address-family iptables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6.
type family struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
v6 bool
// Peer ACL chain bookkeeping.
entries aclEntries
optionalEntries map[chainKey][]entry
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[nbid.RuleID]*Rule
ipsetCounter *ipsetCounter
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
// plumbing that isn't a filter rule).
rules routeRules
// Routing / NAT.
legacyManagement bool
mtu uint16
ipFwdState *ipfwdstate.IPForwardingState
stateManager *statemanager.Manager
}
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
iptablesClient: iptablesClient,
wgIface: wgIface,
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
entries: make(aclEntries),
optionalEntries: make(map[chainKey][]entry),
filters: make(map[nbid.RuleID]*Rule),
rules: make(routeRules),
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
r.ipsetCounter = refcounter.New(
func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
return r, nil
}
// init wires the family to the state manager and installs both the
// route ACL containers and the peer ACL chain skeleton.
func (r *family) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.seedInitialEntries()
r.seedInitialOptionalEntries()
if err := r.cleanAclChains(); err != nil {
return fmt.Errorf("clean acl chains: %w", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
r.updateState()
return nil
}
// Reset tears down all firewall state owned by this family. ACL
// chain cleanup runs before route-chain cleanup because the route
// chains are still referenced by FORWARD jumps installed during
// seedInitialEntries; deleting them first would trip EBUSY.
func (r *family) Reset() error {
var merr *multierror.Error
if err := r.cleanAclChains(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(routeRules)
r.filters = make(map[nbid.RuleID]*Rule)
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
if r.v6 {
currentState.RouteRules6 = r.rules
currentState.RouteIPsetCounter6 = r.ipsetCounter
currentState.ACLEntries6 = r.entries
} else {
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
currentState.ACLEntries = r.entries
}
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}

View File

@@ -0,0 +1,331 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"slices"
"strconv"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
nbnet "github.com/netbirdio/netbird/client/net"
)
// AddFilterRule installs a packet-filtering rule. With destination
// empty, the rule goes to the peer ACL input chain plus a paired
// mangle PREROUTING rule for the redirect mark. With destination set
// (prefix or named set), it goes to the route ACL forward chain.
// Multi-source rules collapse to one iptables rule via the shared
// hash:net ipset.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
if err != nil {
return nil, fmt.Errorf("apply source match: %w", err)
}
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
if err != nil {
r.dropSourceMatch(srcMatch)
return nil, err
}
r.filters[ruleID] = rule
r.updateState()
return rule, nil
}
func (r *family) hasRule(id nbid.RuleID) bool {
_, ok := r.filters[id]
return ok
}
// hasDNATRule reports whether this family owns the DNAT rule set for
// the given user id. DNAT rules live in r.rules under the well-known
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
// to pick the right family.
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. The
// rule's stored chain/table identify where to delete from; source set
// references are recovered from the spec via findSets and dropped
// from the shared ipset counter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
if err := r.iptablesClient.Delete(tableFilter, pr.chain, pr.specs...); err != nil {
return fmt.Errorf("delete rule %s: %w", pr.chain, err)
}
if pr.mangleSpecs != nil {
if err := r.iptablesClient.Delete(tableMangle, chainRTPRE, pr.mangleSpecs...); err != nil {
log.Errorf("delete mangle rule: %v", err)
}
}
r.dropSourceMatch(pr.specs)
delete(r.filters, ruleID)
r.updateState()
return nil
}
// findSets scans an iptables rule spec for "-m set --match-set <name>
// <dir>" fragments and returns the named sets in occurrence order.
// Used at delete time to drop ipsetCounter references.
func findSets(rule []string) []string {
var sets []string
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
sets = append(sets, rule[i+3])
}
}
return sets
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
// applySourceMatch returns the iptables match fragment for the rule's
// source. For a Set it increments the shared ipset's refcount; for a
// Prefix it emits a direct -s match; for the wildcard it returns nil.
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
switch {
case network.IsSet():
if r.ipsetCounter == nil {
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
}
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
}
return []string{"-m", "set", matchSet, name, "src"}, nil
case network.IsPrefix():
return []string{"-s", network.Prefix.String()}, nil
default:
return nil, nil
}
}
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
// call when the spec is empty or holds only inline matchers. Decrement
// errors are logged but not returned: the filter rule has already been
// deleted at that point and we don't want to leak the deletion.
func (r *family) dropSourceMatch(srcMatch []string) {
if r.ipsetCounter == nil {
return
}
for _, name := range findSets(srcMatch) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
log.Errorf("rollback ipset decrement %s: %v", name, err)
}
}
}
// decrementSetCounter drops ipset references owned by a raw rule spec
// stored in r.rules (NAT / legacy route entries). It returns an error
// aggregate so the caller surfaces decrement failures.
func (r *family) decrementSetCounter(rule []string) error {
if r.ipsetCounter == nil {
return nil
}
var merr *multierror.Error
for _, name := range findSets(rule) {
if _, err := r.ipsetCounter.Decrement(name); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// installFilterRule assembles and writes one iptables filter-chain
// rule. With destination empty the rule lands in the peer ACL input
// chain and a paired mangle PREROUTING rule is added for the redirect
// mark. With destination set the rule lands in the route ACL forward
// chain and there is no mangle pairing.
func (r *family) installFilterRule(
ruleID nbid.RuleID,
srcMatch []string,
destination firewall.Network,
protocol firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (*Rule, error) {
isRoute := destination.IsPrefix() || destination.IsSet()
proto := protoForFamily(protocol, r.v6)
specs := slices.Clone(srcMatch)
var destExp []string
if isRoute {
var err error
destExp, err = r.applyNetwork("-d", destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
}
specs = append(specs, destExp...)
}
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
var mangleSpecs []string
if !isRoute {
mangleSpecs = slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", r.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
}
specs = append(specs, "-j", actionToStr(action))
chain := chainNameInputRules
if isRoute {
chain = chainRTFWDIN
}
// Peer ACL drops are inserted at position 1 so they precede the
// chain's catch-all; route ACL drops are inserted at position 2
// to sit immediately after the established/related accept rule.
var err error
if action == firewall.ActionDrop {
pos := 1
if isRoute {
pos = 2
}
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
} else {
err = r.iptablesClient.Append(tableFilter, chain, specs...)
}
if err != nil {
r.dropSourceMatch(destExp)
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
}
if mangleSpecs != nil {
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("add mangle rule: %v", err)
mangleSpecs = nil
}
}
return &Rule{
id: ruleID,
specs: specs,
mangleSpecs: mangleSpecs,
chain: chain,
v6: r.v6,
}, nil
}
// applyNetwork resolves a firewall.Network into the iptables match
// fragment for the given direction flag (-s or -d). Set networks
// increment the shared ipset refcount; prefixes emit a direct match;
// an empty network returns no spec ("match any").
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
name := r.ipsetName(network.Set.HashedName())
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, name, direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
if v6 && protocol == firewall.ProtocolICMP {
return "ipv6-icmp"
}
return string(protocol)
}
// filterMatchSpecs returns the proto/port match fragment for a
// filtering rule. The source match (-s or -m set) is built by the
// caller and prepended.
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
if action == firewall.ActionAccept {
return "ACCEPT"
}
return "DROP"
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
}
if port.IsRange && len(port.Values) == 2 {
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
}
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(int(p))
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}

View File

@@ -0,0 +1,97 @@
//go:build !android
package iptables
import (
"fmt"
"net/netip"
"github.com/hashicorp/go-multierror"
ipset "github.com/lrh3321/ipset-go"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
if err := r.createIPSet(setName); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *family) deleteIpSet(setName string) error {
if err := r.destroyIPSet(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
name := r.ipsetName(set.HashedName())
var merr *multierror.Error
for _, prefix := range prefixes {
if err := r.addPrefixToIPSet(name, prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", name, prefixes)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) ipsetName(name string) string {
if r.v6 {
return name + "-v6"
}
return name
}
func (r *family) createIPSet(name string) error {
opts := ipset.CreateOptions{
Replace: true,
}
if r.v6 {
opts.Family = ipset.FamilyIPV6
}
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
return fmt.Errorf("create ipset %s: %w", name, err)
}
log.Debugf("created ipset %s with type hash:net", name)
return nil
}
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
addr := prefix.Addr()
ip := addr.AsSlice()
entry := &ipset.Entry{
IP: ip,
CIDR: uint8(prefix.Bits()),
Replace: true,
}
if err := ipset.Add(name, entry); err != nil {
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
}
return nil
}
func (r *family) destroyIPSet(name string) error {
return ipset.Destroy(name)
}

View File

@@ -3,7 +3,6 @@ package iptables
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
@@ -18,25 +17,21 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type resetter interface {
Reset() error
}
// Manager of iptables firewall
// Manager of iptables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
type Manager struct {
mutex sync.Mutex
wgIface iFaceMapper
ipv4Client *iptables.IPTables
aclMgr *aclManager
router *router
family4 *family
rawSupported bool
// IPv6 counterparts, nil when no v6 overlay
ipv6Client *iptables.IPTables
aclMgr6 *aclManager
router6 *router
family6 *family
}
// iFaceMapper defines subset methods of interface required for manager
@@ -57,14 +52,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface, mtu)
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
return nil, fmt.Errorf("create family: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -83,19 +73,14 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
}
m.ipv6Client = ip6Client
m.router6, err = newRouter(ip6Client, wgIface, mtu)
m.family6, err = newFamily(ip6Client, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
return fmt.Errorf("create v6 family: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// Share the same IP forwarding state with the v4 family, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
m.family6.ipFwdState = m.family4.ipFwdState
return nil
}
@@ -109,7 +94,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
MTU: m.family4.mtu,
},
}
stateManager.RegisterState(state)
@@ -141,31 +126,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
return nil
}
// initChains initializes router and ACL chains for both address families,
// rolling back on failure.
// initChains initializes the per-family firewall state for both
// address families, rolling back on failure.
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
type initStep struct {
name string
init func(*statemanager.Manager) error
mgr resetter
r *family
}
steps := []initStep{
{"router", m.router.init, m.router},
{"acl manager", m.aclMgr.init, m.aclMgr},
}
steps := []initStep{{"v4", m.family4}}
if m.hasIPv6() {
steps = append(steps,
initStep{"v6 router", m.router6.init, m.router6},
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
)
steps = append(steps, initStep{"v6", m.family6})
}
var initialized []initStep
for _, s := range steps {
if err := s.init(stateManager); err != nil {
if err := s.r.init(stateManager); err != nil {
for i := len(initialized) - 1; i >= 0; i-- {
if rerr := initialized[i].mgr.Reset(); rerr != nil {
if rerr := initialized[i].r.Reset(); rerr != nil {
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
}
}
@@ -176,84 +154,78 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
return nil
}
// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if ip.To4() != nil {
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
// docs for destination semantics. Mixed-family source lists are split
// and dispatched to the v4 / v6 backends.
func (m *Manager) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
) ([]firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
m.mutex.Lock()
defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
isRoute := destination.IsPrefix() || destination.IsSet()
if isRoute {
fam := m.family4
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
fam = m.family6
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
rule, err := fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
return []firewall.Rule{rule}, nil
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
v4, v6 := splitSourcesByFamily(sources)
var out []firewall.Rule
if len(v4) > 0 {
rule, err := m.family4.AddFilterRule(id, v4, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
out = append(out, rule)
}
return len(sources) > 0 && sources[0].Addr().Is6()
if len(v6) > 0 {
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for v6 sources %v: %w", v6, firewall.ErrIPv6NotInitialized)
}
rule, err := m.family6.AddFilterRule(id, v6, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
out = append(out, rule)
}
return out, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// DeleteFilterRule removes a rule previously added via AddFilterRule.
// The rule is looked up by id in each family's filter cache.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6IptRule(rule) {
return m.aclMgr6.DeletePeerRule(rule)
id := rule.ID()
if m.family4.hasRule(id) {
return m.family4.DeleteFilterRule(rule)
}
return m.aclMgr.DeletePeerRule(rule)
}
func isIPv6IptRule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.v6
}
// DeleteRouteRule deletes a routing rule.
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
return m.router6.DeleteRouteRule(rule)
if m.hasIPv6() && m.family6.hasRule(id) {
return m.family6.DeleteFilterRule(rule)
}
return m.router.DeleteRouteRule(rule)
log.Debugf("filter rule %s not found in any family", id)
return nil
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -272,10 +244,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
return m.family6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
if err := m.family4.AddNatRule(pair); err != nil {
return err
}
@@ -284,7 +256,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
if err := m.family6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -300,18 +272,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
return m.family6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
if err := m.family4.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -320,11 +292,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
return firewall.SetLegacyManagement(m.family6, isLegacy)
}
return nil
}
@@ -341,19 +313,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
}
if m.hasIPv6() {
if err := m.aclMgr6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
}
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
}
}
if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
}
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
@@ -377,11 +343,11 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
// rules so that packet filtering is handled in userspace instead of by netfilter.
func (m *Manager) AllowNetbird() error {
var merr *multierror.Error
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
}
if m.hasIPv6() {
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
}
}
@@ -402,14 +368,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -424,9 +390,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
return m.family6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
return m.family4.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -434,10 +400,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
return m.router6.DeleteDNATRule(rule)
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
return m.family6.DeleteDNATRule(rule)
}
return m.router.DeleteDNATRule(rule)
return m.family4.DeleteDNATRule(rule)
}
// UpdateSet updates the set with the given prefixes
@@ -454,12 +420,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -476,9 +442,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -490,9 +456,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -504,9 +470,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -518,9 +484,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
@@ -654,3 +620,22 @@ func (m *Manager) cleanupNoTrackChain() error {
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}
// splitSourcesByFamily partitions a mixed-family prefix list.
func splitSourcesByFamily(sources []netip.Prefix) (v4, v6 []netip.Prefix) {
for _, p := range sources {
if p.Addr().Is4() || p.Addr().Is4In6() {
v4 = append(v4, p)
} else {
v6 = append(v6, p)
}
}
return v4, v6
}
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,3 +1,5 @@
//go:build integration && !android
package iptables
import (
@@ -72,7 +74,7 @@ func TestIptablesManager(t *testing.T) {
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -83,18 +85,16 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
err := manager.DeleteFilterRule(r)
require.NoError(t, err, "failed to delete rule")
}
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
})
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
err = manager.Close(nil)
@@ -126,7 +126,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{22}}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
require.NoError(t, err, "failed to add deny rule")
require.NotEmpty(t, rule, "deny rule should not be empty")
@@ -142,11 +142,11 @@ func TestIptablesManagerDenyRules(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
// Add accept rule first
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for same IP/port - this should take precedence
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
require.NoError(t, err, "failed to add deny rule")
// Inspect the actual iptables rules to verify deny rule comes before accept rule
@@ -159,19 +159,23 @@ func TestIptablesManagerDenyRules(t *testing.T) {
t.Logf(" [%d] %s", i, rule)
}
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
// match. Match on that shape instead of the legacy
// per-(action,port) ipset names ("deny-http"/"accept-http")
// that this test predates.
srcMatch := fmt.Sprintf("-s %s/32", ip)
var denyRuleIndex, acceptRuleIndex = -1, -1
for i, rule := range rules {
if strings.Contains(rule, "DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
denyRuleIndex = i
}
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
continue
}
if strings.Contains(rule, "ACCEPT") {
if strings.Contains(rule, "-j DROP") {
t.Logf("Found DROP rule at index %d: %s", i, rule)
denyRuleIndex = i
}
if strings.Contains(rule, "-j ACCEPT") {
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
acceptRuleIndex = i
}
acceptRuleIndex = i
}
}
@@ -196,7 +200,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
},
}
// just check on the local interface
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -211,26 +214,41 @@ func TestIptablesManagerIPSet(t *testing.T) {
}()
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
require.Len(t, rule2, 1)
require.Contains(t, rule2[0].(*Rule).specs, "-s",
"single-source rule should use direct -s match, not an ipset")
require.Empty(t, findSets(rule2[0].(*Rule).specs),
"single-source rule should not allocate a shared ipset")
})
t.Run("delete single-source rule", func(t *testing.T) {
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
err := manager.DeleteFilterRule(r)
require.NoError(t, err, "failed to delete rule")
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
t.Run("multi-source uses shared ipset", func(t *testing.T) {
sources := []netip.Prefix{
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
}
port := &fw.Port{Values: []uint16{8080}}
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add multi-source rule")
require.Len(t, multi, 1, "multi-source rule must produce exactly one iptables rule")
sets := findSets(multi[0].(*Rule).specs)
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
require.NoError(t, manager.DeleteFilterRule(multi[0]))
})
t.Run("reset check", func(t *testing.T) {
@@ -281,7 +299,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
}

View File

@@ -0,0 +1,269 @@
//go:build !android
package iptables
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if !pair.Masquerade {
return nil
}
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
}
return nil
}
// GetLegacyManagement returns the current legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
} else {
delete(r.rules, k)
}
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %w", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %w", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.v6 {
overhead = ipv6TCPHeaderSize
}
mss := r.mtu - overhead
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSCLAMP,
}
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *family) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
ruleKey := firewall.RuleID("established-" + chain)
r.rules[ruleKey] = establishedRule
return nil
}
func (r *family) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
)
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
}
r.rules[ruleKey] = rule
r.updateState()
return nil
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
}
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else {
log.Debugf("marking rule %s not found", ruleKey)
}
r.updateState()
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build integration && !android
package iptables
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
require.NoError(t, err, "Failed to reset family")
}()
tests := []struct {
@@ -334,62 +334,30 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
stored, ok := r.filters[ruleKey.ID()]
require.True(t, ok, "rule not stored in filters")
t.Logf("Internal rule: %v", stored.specs)
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, stored.specs...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content
params := routeFilteringRuleParams{
Source: source,
Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
}
expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet {
setName := firewall.NewPrefixSet(tt.sources).HashedName()
expectedRule, err = r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
})
}
}
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
@@ -430,7 +398,7 @@ func TestFindSetNameInRule(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
result := findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)

View File

@@ -1,18 +1,20 @@
package iptables
// Rule to handle management of rules
type Rule struct {
ruleID string
ipsetName string
import "github.com/netbirdio/netbird/client/firewall/manager"
// Rule to handle management of rules. Source set membership (when the
// rule was built against a shared hash:net ipset) is encoded in specs;
// DeleteFilterRule recovers it via findSets so the refcounter can drop
// the right reference.
type Rule struct {
id manager.RuleID
specs []string
mangleSpecs []string
ip string
chain string
v6 bool
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
}

View File

@@ -1,103 +0,0 @@
package iptables
import "encoding/json"
type ipList struct {
ips map[string]struct{}
}
func newIpList(ip string) *ipList {
ips := make(map[string]struct{})
ips[ip] = struct{}{}
return &ipList{
ips: ips,
}
}
func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{}
}
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
if temp.IPs == nil {
temp.IPs = make(map[string]struct{})
}
return nil
}
type ipsetStore struct {
ipsets map[string]*ipList
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsets: make(map[string]*ipList),
}
}
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) ipsetNames() []string {
names := make([]string, 0, len(s.ipsets))
for name := range s.ipsets {
names = append(names, name)
}
return names
}
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
if temp.IPSets == nil {
temp.IPSets = make(map[string]*ipList)
}
return nil
}

View File

@@ -32,14 +32,12 @@ type ShutdownState struct {
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
// IPv6 counterparts
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
}
func (s *ShutdownState) Name() string {
@@ -57,17 +55,14 @@ func (s *ShutdownState) Cleanup() error {
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
ipt.family4.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
ipt.family4.entries = s.ACLEntries
}
// Clean up v6 state even if the current run has no IPv6.
@@ -79,16 +74,13 @@ func (s *ShutdownState) Cleanup() error {
}
if ipt.hasIPv6() {
if s.RouteRules6 != nil {
ipt.router6.rules = s.RouteRules6
ipt.family6.rules = s.RouteRules6
}
if s.RouteIPsetCounter6 != nil {
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
}
if s.ACLEntries6 != nil {
ipt.aclMgr6.entries = s.ACLEntries6
}
if s.ACLIPsetStore6 != nil {
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
ipt.family6.entries = s.ACLEntries6
}
}

View File

@@ -0,0 +1,23 @@
//go:build integration && !android
package iptables
import (
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, _ := netip.AddrFromSlice(ip)
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -3,7 +3,6 @@ package manager
import (
"errors"
"fmt"
"net"
"net/netip"
"sort"
@@ -16,6 +15,12 @@ import (
// method but the IPv6 firewall components were not initialized.
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
// ErrNoSources is returned when AddFilterRule is called with an empty
// source list. "Match any source" must be expressed explicitly with a
// /0 prefix; an empty list is a caller error and is rejected rather
// than silently widening the rule to every source.
var ErrNoSources = errors.New("rule has no sources")
const (
ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t"
@@ -23,13 +28,22 @@ const (
NatFormat = "netbird-nat-%s-%t"
)
// RuleID identifies a firewall rule. It is a typed string so the
// compiler catches accidental mixing with arbitrary string keys.
// RuleID itself satisfies the Rule interface so callers can drop a
// bare key into APIs like DeleteFilterRule without wrapping it.
type RuleID string
// ID implements the Rule interface for a bare RuleID.
func (r RuleID) ID() RuleID { return r }
// Rule abstraction should be implemented by each firewall manager
//
// Each firewall type for different OS can use different type
// of the properties to hold data of the created rule
type Rule interface {
// ID returns the rule id
ID() string
ID() RuleID
}
// RuleDirection is the traffic direction which a rule is applied
@@ -101,43 +115,42 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall
// AddFilterRule adds a packet-filtering rule to the firewall.
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
// If destination is the zero Network, the rule applies to traffic
// inbound to this node, i.e. peer ACL semantics, installed in
// the kernel's input chain. If destination is set (prefix or
// set), the rule applies to forwarded traffic with that
// destination, route ACL semantics, installed in the forward
// chain.
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
// sources may mix IPv4 and IPv6 prefixes; backends split by
// family and return one rule per family. "Match any source" must
// be expressed with an explicit /0 prefix; an empty sources list
// is rejected with ErrNoSources so a zeroed list can never widen a
// rule to every source.
//
// Note: callers should call Flush() after adding rules.
AddFilterRule(
id []byte,
ip net.IP,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
ipsetName string,
) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition
DeletePeerRule(rule Rule) error
// DeleteFilterRule removes a filtering rule previously added via
// AddFilterRule. The rule's own type identifies whether it lives
// in the peer (input) or route (forward) path.
DeleteFilterRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort, dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
@@ -185,8 +198,8 @@ type Manager interface {
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
}
func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse)
func GenKey(format string, pair RouterPair) RuleID {
return RuleID(fmt.Sprintf(format, pair.ID, pair.Inverse))
}
// LegacyManager defines the interface for legacy management operations
@@ -242,6 +255,20 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
return merged
}
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
// plain v4 form, shifting the prefix length out of the 96-bit mapped
// range. Other prefixes are returned unchanged. Keeping prefixes
// unmapped ensures v4 rules match consistently and the match builders
// read the correct address length.
func UnmapPrefix(p netip.Prefix) netip.Prefix {
addr := p.Addr()
if !addr.Is4In6() {
return p
}
bits := max(p.Bits()-96, 0)
return netip.PrefixFrom(addr.Unmap(), bits)
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {

View File

@@ -13,13 +13,13 @@ type ForwardRule struct {
TranslatedPort Port
}
func (r ForwardRule) ID() string {
func (r ForwardRule) ID() RuleID {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
return RuleID(id)
}
func (r ForwardRule) String() string {

View File

@@ -40,7 +40,8 @@ func (h Set) Comment() string {
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
// sort a copy for consistent naming without mutating the caller's slice
prefixes = slices.Clone(prefixes)
SortPrefixes(prefixes)
hash := sha256.New()

View File

@@ -1,713 +0,0 @@
package nftables
import (
"bytes"
"fmt"
"net"
"slices"
"strconv"
"strings"
"time"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
const (
// rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
)
const flushError = "flush: %w"
type AclManager struct {
rConn *nftables.Conn
sConn *nftables.Conn
wgIface iFaceMapper
routingFwChainName string
af addrFamily
workTable *nftables.Table
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err)
}
return &AclManager{
rConn: &nftables.Conn{},
sConn: sConn,
wgIface: wgIface,
workTable: table,
routingFwChainName: routingFwChainName,
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule),
}, nil
}
func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var ipset *nftables.Set
if ipsetName != "" {
var err error
ipset, err = m.addIpToSet(ipsetName, ip)
if err != nil {
return nil, err
}
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil {
return nil, err
}
newRules = append(newRules, ioRule)
return newRules, nil
}
// DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
if r.nftSet == nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush()
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
if err != nil {
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
}
if err := m.sConn.Flush(); err != nil {
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
return err
}
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
}
// if after delete, set still contains other IPs,
// no need to delete firewall rule and we should exit here
if len(ips) > 0 {
return nil
}
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err
}
delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
return nil
}
// we delete last IP from the set, that means we need to delete
// set itself and associated firewall rule too
m.rConn.FlushSet(r.nftSet)
m.rConn.DelSet(r.nftSet)
m.ipsetStore.deleteIpset(r.nftSet.Name)
return nil
}
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (m *AclManager) Flush() error {
if err := m.flushWithBackoff(); err != nil {
return err
}
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
mangleRule: r.mangleRule,
nftSet: r.nftSet,
ruleID: r.ruleID,
ip: ip,
}, nil
}
var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.protoOffset,
Len: uint32(1),
})
protoData, err := m.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %v", err)
}
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: []byte{protoData},
})
}
rawIP := ipToBytes(ip, m.af)
// check if rawIP contains zeroed IPv4 0.0.0.0 value
// in that case not add IP match expression into the rule definition
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: m.af.srcAddrOffset,
Len: m.af.addrLen,
},
)
// add individual IP for match if no ipset defined
if ipset == nil {
expressions = append(expressions,
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
},
)
} else {
expressions = append(expressions,
&expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
},
)
}
}
expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, applyPort(dPort, false)...)
mainExpressions := slices.Clone(expressions)
switch action {
case firewall.ActionAccept:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(ruleId)
chain := m.chainInputRules
rule := &nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: mainExpressions,
UserData: userData,
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end
var nftRule *nftables.Rule
if action == firewall.ActionDrop {
nftRule = m.rConn.InsertRule(rule)
} else {
nftRule = m.rConn.AddRule(rule)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = ruleStruct
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return ruleStruct, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
// Chain is created by route manager
// TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
}
chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain
}
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: m.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return m.rConn.AddChain(chain)
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
rawIP := ipToBytes(ip, m.af)
if err != nil {
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
return nil, fmt.Errorf("get set name: %v", err)
}
m.ipsetStore.newIpset(ipset.Name)
}
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
return ipset, nil
}
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
return nil, fmt.Errorf("add set element for the first time: %v", err)
}
m.ipsetStore.AddIpToSet(ipset.Name, ip)
if err := m.sConn.Flush(); err != nil {
return nil, fmt.Errorf("flush add elements: %v", err)
}
return ipset, nil
}
// createSet in given table by name
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
ipset := &nftables.Set{
Name: name,
Table: table,
Dynamic: true,
KeyType: m.af.setKeyType,
}
if err := m.rConn.AddSet(ipset, nil); err != nil {
return nil, fmt.Errorf("create set: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush created set: %v", err)
}
return ipset, nil
}
func (m *AclManager) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil {
return nil
}
list, err := m.rConn.GetRules(m.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])]
if ok {
if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
}
}
return nil
}
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}
rulesetID += ":"
if dPort != nil {
rulesetID += dPort.String()
}
rulesetID += ":"
rulesetID += strconv.Itoa(int(action))
if ipset == nil {
return "ip:" + ip.String() + rulesetID
}
return "set:" + ipset.Name + rulesetID
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
// ipToBytes converts net.IP to the correct byte length for the address family.
func ipToBytes(ip net.IP, af addrFamily) []byte {
if af.addrLen == 4 {
return ip.To4()
}
return ip.To16()
}

View File

@@ -3,6 +3,7 @@ package nftables
import (
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"golang.org/x/sys/unix"
@@ -63,6 +64,14 @@ func familyForAddr(is4 bool) addrFamily {
return afIPv6
}
// zeroPrefix returns the family's unspecified prefix (/0).
func (af addrFamily) zeroPrefix() netip.Prefix {
if af.addrLen == net.IPv4len {
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
// protoNum converts a firewall protocol to the IP protocol number,
// using the correct ICMP variant for the address family.
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {

View File

@@ -0,0 +1,885 @@
//go:build !android
package nftables
import (
"bytes"
"errors"
"fmt"
"slices"
"strings"
"time"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePostrouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
r.addPostroutingRules()
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("initialize tables: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *family) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
return nil
}
func (r *family) acceptForwardRules() error {
var merr *multierror.Error
if err := r.acceptFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) acceptFilterTableRules() error {
if r.filterTable == nil {
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available.
// Use the correct protocol (iptables vs ip6tables) for the address family.
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
if err := r.acceptFilterRulesIptables(ipt); err != nil {
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
fw = "nftables"
return r.acceptFilterRulesNftables(r.filterTable)
}
return nil
}
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables forward rule: %v", rule)
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *family) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
// This is used when iptables is not available.
func (r *family) acceptFilterRulesNftables(table *nftables.Table) error {
intf := ifname(r.wgIface.Name())
forwardChain := &nftables.Chain{
Name: chainNameForward,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
}
r.insertForwardAcceptRules(forwardChain, intf)
inputChain := &nftables.Chain{
Name: chainNameInput,
Table: table,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
}
r.insertInputAcceptRule(inputChain, intf)
return r.conn.Flush()
}
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
// It dynamically finds chains at call time to handle chains that may have been created after startup.
func (r *family) acceptExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
intf := ifname(r.wgIface.Name())
for _, chain := range chains {
r.applyExternalChainAccept(chain, intf)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush external chain rules: %w", err)
}
return nil
}
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
if chain.Hooknum == nil {
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
return
}
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
switch *chain.Hooknum {
case *nftables.ChainHookForward:
r.insertForwardAcceptRules(chain, intf)
case *nftables.ChainHookInput:
r.insertInputAcceptRule(chain, intf)
}
}
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
r.insertForwardIifRule(chain, intf, existing)
r.insertForwardOifEstablishedRule(chain, intf, existing)
}
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleIif] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
})
}
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
if existing[userDataAcceptForwardRuleOif] {
return
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: append(exprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
})
}
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
existing, err := r.existingNetbirdRulesInChain(chain)
if err != nil {
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
return
}
if existing[userDataAcceptInputRule] {
return
}
r.conn.InsertRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
})
}
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return nil, fmt.Errorf("list rules: %w", err)
}
present := map[string]bool{}
for _, rule := range rules {
if !isNetbirdAcceptRuleTag(rule.UserData) {
continue
}
present[string(rule.UserData)] = true
}
return present, nil
}
func isNetbirdAcceptRuleTag(userData []byte) bool {
switch string(userData) {
case userDataAcceptForwardRuleIif,
userDataAcceptForwardRuleOif,
userDataAcceptInputRule:
return true
}
return false
}
func (r *family) removeAcceptFilterRules() error {
var merr *multierror.Error
if err := r.removeFilterTableRules(); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeFilterTableRules() error {
if r.filterTable == nil {
return nil
}
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
if err != nil {
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
return r.removeAcceptRulesFromTable(r.filterTable)
}
return nil
}
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != table.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
return err
}
}
return r.conn.Flush()
}
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
}
}
}
return nil
}
// removeExternalChainsRules removes our accept rules from all external chains.
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
// ensuring cleanup works even after a crash or if chains changed.
func (r *family) removeExternalChainsRules() error {
chains := r.findExternalChains()
if len(chains) == 0 {
return nil
}
for _, chain := range chains {
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
}
}
return r.conn.Flush()
}
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
func (r *family) findExternalChains() []*nftables.Chain {
var chains []*nftables.Chain
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
for _, family := range families {
allChains, err := r.conn.ListChainsOfTableFamily(family)
if err != nil {
log.Debugf("list chains for family %d: %v", family, err)
continue
}
for _, chain := range allChains {
if r.isExternalChain(chain) {
chains = append(chains, chain)
}
}
}
return chains
}
func (r *family) isExternalChain(chain *nftables.Chain) bool {
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
return false
}
if chain.Type != nftables.ChainTypeFilter {
return false
}
if chain.Hooknum == nil {
return false
}
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
}
func isIptablesTable(name string) bool {
switch name {
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
return true
}
return false
}
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainInputRules,
Position: 0,
Exprs: expIn,
})
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
// Flush rule/chain/set operations from the buffer
//
// Method also get all rules after flush and refreshes handle values in the rulesets
func (r *family) Flush() error {
if err := r.flushWithBackoff(); err != nil {
return err
}
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if r.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
nfRule := r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := r.conn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (r *family) createDefaultChains() (err error) {
// chainNameInputRules
chain := r.createChain(chainNameInputRules)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err)
}
r.chainInputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err)
}
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
log.Errorf("failed to allow redirected traffic: %s", err)
}
return nil
}
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
r.chainPrerouting = r.chains[chainNameManglePrerouting]
r.addFwmarkToForward(chainFwFilter)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
})
}
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: r.routingFwChainName,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chainFwFilter,
Exprs: expressions,
})
}
func (r *family) createChain(name string) *nftables.Chain {
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
}
chain = r.conn.AddChain(chain)
insertReturnTrafficRule(r.conn, r.workTable, chain)
return chain
}
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: name,
Table: r.workTable,
Hooknum: hookNum,
Priority: nftables.ChainPriorityFilter,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
return r.conn.AddChain(chain)
}
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: expressions,
})
return nil
}
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
_ = r.conn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (r *family) flushWithBackoff() (err error) {
backoff := 4
backoffTime := 1000 * time.Millisecond
for i := 0; ; i++ {
err = r.conn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
log.Error("failed to flush nftables, retrying...")
if i == backoff-1 {
return err
}
time.Sleep(backoffTime)
backoffTime *= 2
continue
}
break
}
return
}
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if r.workTable == nil || chain == nil {
return nil
}
list, err := r.conn.GetRules(r.workTable, chain)
if err != nil {
return err
}
for _, rule := range list {
if len(rule.UserData) == 0 {
continue
}
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
if !ok {
continue
}
if mangle {
if pr.mangleRule != nil {
*pr.mangleRule = *rule
}
} else {
*pr.nftRule = *rule
}
}
return nil
}

View File

@@ -0,0 +1,533 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := r.af.protoNum(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey firewall.RuleID) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleKey firewall.RuleID, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
natTable := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
dnatRule := &nftables.Rule{
Table: natTable,
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: natTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey firewall.RuleID) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.dstAddrOffset,
Len: r.af.addrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
var needsFlush bool
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if dnatRule.Handle == 0 {
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
delete(r.rules, ruleKey+dnatSuffix)
} else if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
} else {
needsFlush = true
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if masqRule.Handle == 0 {
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
delete(r.rules, ruleKey+snatSuffix)
} else if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
} else {
needsFlush = true
}
}
if needsFlush {
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 3,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 3,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
RegProtoMax: 0,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add inbound DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveInboundDNAT removes an inbound DNAT rule.
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
func (r *family) ensureNATOutputChain() error {
if _, exists := r.chains[chainNameNATOutput]; exists {
return nil
}
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
Name: chainNameNATOutput,
Table: r.workTable,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
if err := r.conn.Flush(); err != nil {
delete(r.chains, chainNameNATOutput)
return fmt.Errorf("create NAT output chain: %w", err)
}
return nil
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
if _, exists := r.rules[ruleID]; exists {
return nil
}
if err := r.ensureNATOutputChain(); err != nil {
return err
}
protoNum, err := r.af.protoNum(protocol)
if err != nil {
return fmt.Errorf("convert protocol to number: %w", err)
}
exprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: binaryutil.BigEndian.PutUint16(originalPort),
},
}
bits := 32
if localAddr.Is6() {
bits = 128
}
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: localAddr.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(translatedPort),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(r.af.tableFamily),
RegAddrMin: 1,
RegProtoMin: 2,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameNATOutput],
Exprs: exprs,
UserData: []byte(ruleID),
}
r.conn.AddRule(dnatRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("add output DNAT rule: %w", err)
}
r.rules[ruleID] = dnatRule
return nil
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
rule, exists := r.rules[ruleID]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
delete(r.rules, ruleID)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete output DNAT rule: %w", err)
}
delete(r.rules, ruleID)
return nil
}

View File

@@ -0,0 +1,249 @@
//go:build !android
package nftables
import (
"fmt"
"net/netip"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
tableNat = "nat"
tableMangle = "mangle"
tableRaw = "raw"
tableSecurity = "security"
chainNameNatPrerouting = "PREROUTING"
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameNATOutput = "netbird-nat-output"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
// Peer ACL chain names.
chainNameInputRules = "netbird-acl-input-rules"
chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
flushError = "flush: %w"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix firewall.RuleID = "_dnat"
snatSuffix firewall.RuleID = "_snat"
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
ipv4TCPHeaderSize = 40
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
ipv6TCPHeaderSize = 60
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
maxPrefixesSet = 1500
refreshRulesMapError = "refresh rules map: %w"
)
var (
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
)
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
// family holds the per-address-family nftables state. One instance
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
// single family; the top-level Manager owns one for v4 and another
// for v6. The name predates the peer-ACL absorption; it's effectively
// the per-family backend now.
type family struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// filters holds peer + route filter rules keyed by content hash.
// AddFilterRule writes here; DeleteFilterRule looks up by id.
filters map[firewall.RuleID]*Rule
// rules holds NAT, DNAT, and external accept rules (auxiliary
// plumbing that isn't a filter rule).
rules map[firewall.RuleID]*nftables.Rule
// Peer ACL chain handles.
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
routingFwChainName string
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
af addrFamily
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
r := &family{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
filters: make(map[firewall.RuleID]*Rule),
rules: make(map[firewall.RuleID]*nftables.Rule),
routingFwChainName: chainNameRoutingFw,
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
log.Debugf("ip filter table not found: %v", err)
}
return r, nil
}
func (r *family) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
if err := r.createDefaultChains(); err != nil {
return fmt.Errorf("create default acl chains: %w", err)
}
return nil
}
// Reset cleans existing nftables filter table rules from the system
func (r *family) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
if err != nil {
return nil, fmt.Errorf("list tables: %w", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func hookName(hook *nftables.ChainHook) string {
if hook == nil {
return "unknown"
}
switch *hook {
case *nftables.ChainHookForward:
return chainNameForward
case *nftables.ChainHookInput:
return chainNameInput
default:
return fmt.Sprintf("hook(%d)", *hook)
}
}
func familyName(family nftables.TableFamily) string {
switch family {
case nftables.TableFamilyIPv4:
return "ip"
case nftables.TableFamilyIPv6:
return "ip6"
case nftables.TableFamilyINet:
return "inet"
default:
return fmt.Sprintf("family(%d)", family)
}
}
func (r *family) iptablesProto() iptables.Protocol {
if r.af.tableFamily == nftables.TableFamilyIPv6 {
return iptables.ProtocolIPv6
}
return iptables.ProtocolIPv4
}
func (r *family) refreshRulesMap() error {
var merr *multierror.Error
newRules := make(map[firewall.RuleID]*nftables.Rule)
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
// preserve existing entries for this chain since we can't verify their state
for k, v := range r.rules {
if v.Chain != nil && v.Chain.Name == chain.Name {
newRules[k] = v
}
}
continue
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
newRules[firewall.RuleID(rule.UserData)] = rule
}
}
}
r.rules = newRules
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -0,0 +1,444 @@
//go:build !android
package nftables
import (
"fmt"
"net"
"net/netip"
"slices"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
)
// AddFilterRule installs one nftables packet-filter rule. With
// destination empty the rule goes to the peer ACL input chain plus a
// paired prerouting mangle rule for the redirect mark. With
// destination set (prefix or named set) it goes to the route ACL
// forward chain. Multi-source rules collapse to one nftables rule
// backed by the shared refcounted hash:net set.
func (r *family) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
isRoute := destination.IsPrefix() || destination.IsSet()
// Peer ACL with no sources is the v4 wildcard. Route paths never
// hit this branch; their callers always carry a destination so the
// source list can legitimately be empty.
if !isRoute && len(sources) == 0 {
sources = []netip.Prefix{r.af.zeroPrefix()}
}
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existing, ok := r.filters[ruleID]; ok {
return existing, nil
}
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs, err := r.buildFilterExprs(srcExprs, destination, proto, sPort, dPort, isRoute)
if err != nil {
r.dropSourceMatch(srcExprs)
return nil, err
}
mainExprs := slices.Clone(exprs)
verdict := expr.VerdictAccept
if action == firewall.ActionDrop {
verdict = expr.VerdictDrop
}
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
chain := r.chainInputRules
if isRoute {
chain = r.chains[chainNameRoutingFw]
}
userData := []byte(ruleID)
nftRule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: mainExprs,
UserData: userData,
}
if action == firewall.ActionDrop {
nftRule = r.conn.InsertRule(nftRule)
} else {
nftRule = r.conn.AddRule(nftRule)
}
if err := r.conn.Flush(); err != nil {
r.dropSourceMatch(exprs)
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
nftRule: nftRule,
sources: sources,
id: ruleID,
}
if !isRoute {
rule.mangleRule = r.createPreroutingRule(exprs, userData)
}
r.filters[ruleID] = rule
if isRoute {
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
sources, destination, proto, sPort, dPort, action)
}
return rule, nil
}
// buildFilterExprs assembles the non-verdict portion of a filter
// rule. Route rules use Meta L4PROTO + Counter; peer rules read the
// IP-header protocol byte via Payload and skip the counter, matching
// the historical shapes so the per-rule kernel state is identical to
// pre-unification.
func (r *family) buildFilterExprs(
srcExprs []expr.Any,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
isRoute bool,
) ([]expr.Any, error) {
var exprs []expr.Any
if isRoute {
exprs = append(exprs, srcExprs...)
destExprs, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExprs...)
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
r.dropSourceMatch(destExprs)
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
return exprs, nil
}
// Peer ACL shape: protocol header read first, then source, then ports.
if proto != firewall.ProtocolALL {
protoNum, err := r.af.protoNum(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: r.af.protoOffset,
Len: 1,
},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
)
}
exprs = append(exprs, srcExprs...)
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
return exprs, nil
}
func (r *family) hasRule(id firewall.RuleID) bool {
_, ok := r.filters[id]
return ok
}
func (r *family) hasDNATRule(id firewall.RuleID) bool {
_, ok := r.rules[id+dnatSuffix]
return ok
}
// DeleteFilterRule removes a previously installed filter rule. Source
// set references are recovered from the stored rule's expressions via
// findSets and dropped from the shared refcounter.
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
ruleID := rule.ID()
pr, ok := r.filters[ruleID]
if !ok {
log.Debugf("filter rule %s not found", ruleID)
return nil
}
// A freshly added rule carries no handle until it is read back from
// the kernel, and Flush only refreshes the peer chains. Pull live
// handles for this rule's chain before deciding it is stale so route
// rules (which Flush never refreshes) can actually be deleted.
if pr.nftRule.Handle == 0 {
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
}
if pr.mangleRule != nil {
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
log.Warnf("refresh mangle handles: %v", err)
}
}
}
if pr.nftRule.Handle == 0 {
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
r.dropSourceMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
if err := r.conn.DelRule(pr.nftRule); err != nil {
log.Errorf("queue rule delete: %v", err)
}
if pr.mangleRule != nil {
if err := r.conn.DelRule(pr.mangleRule); err != nil {
log.Errorf("queue mangle rule delete: %v", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush delete %s: %w", ruleID, err)
}
r.dropSourceMatch(pr.nftRule.Exprs)
delete(r.filters, ruleID)
return nil
}
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
if r.ipsetCounter == nil {
return nil
}
sets := findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// findSets scans an nftables rule's expressions for expr.Lookup and
// returns the named sets in occurrence order. Used at delete time to
// drop ipsetCounter references; peer and route ACLs go through it.
func findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
sets = append(sets, lookup.SetName)
}
}
return sets
}
func (r *family) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return r.applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func (r *family) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
return prefixMatchExprs(r.af, prefix, isSource)
}
// prefixMatchExprs is the family-aware match sequence for a CIDR
// prefix. /0 returns nil; a host prefix (full bit length for the
// family) skips the bitwise step since the mask is all-ones. Shared
// between family and aclManager so both treat single prefixes
// identically.
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
offset := af.dstAddrOffset
if isSource {
offset = af.srcAddrOffset
}
ones := prefix.Bits()
if ones == 0 {
return nil
}
payload := &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: af.addrLen,
}
cmp := &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
}
if ones == af.totalBits {
return []expr.Any{payload, cmp}
}
mask := net.CIDRMask(ones, af.totalBits)
xor := make([]byte, af.addrLen)
return []expr.Any{
payload,
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: af.addrLen,
Mask: mask,
Xor: xor,
},
cmp,
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
offset := uint32(2) // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Range{
Op: expr.CmpOpEq,
Register: 1,
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {
// Handle single port or multiple ports
for i, p := range port.Values {
if i > 0 {
// Add a bitwise OR operation between port checks
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(p),
})
}
}
return exprs
}
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
// sourceNetwork classifies a source-prefix list into the firewall.Network
// shape the rest of the spec-builder consumes: empty for match-any, a
// single prefix inline, or an ipset for multiple sources.
func sourceNetwork(sources []netip.Prefix) firewall.Network {
switch {
case len(sources) == 0:
return firewall.Network{}
case len(sources) == 1 && sources[0].Bits() == 0:
return firewall.Network{}
case len(sources) == 1:
return firewall.Network{Prefix: sources[0]}
default:
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
}
}
// dropSourceMatch undoes whatever the source/destination match
// reserved. Safe to call when the spec is empty or holds only inline
// matchers.
func (r *family) dropSourceMatch(exprs []expr.Any) {
if r.ipsetCounter == nil {
return
}
for _, e := range exprs {
lookup, ok := e.(*expr.Lookup)
if !ok {
continue
}
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
}
}
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}

View File

@@ -0,0 +1,176 @@
//go:build !android
package nftables
import (
"encoding/binary"
"fmt"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
set: set,
prefixes: prefixes,
})
if err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return r.getIpSetExprs(ref, isSource)
}
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
prefixes := firewall.MergeIPRanges(input.prefixes)
nfset := &nftables.Set{
Name: setName,
Comment: input.set.Comment(),
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: r.af.setKeyType,
}
elements := r.convertPrefixesToSet(prefixes)
nElements := len(elements)
maxElements := maxPrefixesSet * 2
initialElements := elements[:min(maxElements, nElements)]
if err := r.conn.AddSet(nfset, initialElements); err != nil {
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
var subEnd int
for subStart := maxElements; subStart < nElements; subStart += maxElements {
subEnd = min(subStart+maxElements, nElements)
subElement := elements[subStart:subEnd]
nSubPrefixes := len(subElement) / 2
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
}
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
return nfset, nil
}
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement
for _, prefix := range prefixes {
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
return elements
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
masked := prefix.Masked()
if masked.Addr().Is4() {
hostMask := ^uint32(0) >> masked.Bits()
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// IPv6: set host bits to all 1s
b := masked.Addr().As16()
bits := masked.Bits()
for i := bits; i < 128; i++ {
b[i/8] |= 1 << (7 - i%8)
}
return netip.AddrFrom16(b)
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
if err != nil {
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
}
elements := r.convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset by default
offset := r.af.dstAddrOffset
if isSource {
// src offset
offset = r.af.srcAddrOffset
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: r.af.addrLen,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -1,85 +0,0 @@
package nftables
import (
"net"
)
type ipsetStore struct {
ipsetReference map[string]int
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
}
func newIpsetStore() *ipsetStore {
return &ipsetStore{
ipsetReference: make(map[string]int),
ipsets: make(map[string]map[string]struct{}),
}
}
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
r, ok := s.ipsets[ipsetName]
return r, ok
}
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
s.ipsetReference[ipsetName] = 0
ipList := make(map[string]struct{})
s.ipsets[ipsetName] = ipList
return ipList
}
func (s *ipsetStore) deleteIpset(ipsetName string) {
delete(s.ipsetReference, ipsetName)
delete(s.ipsets, ipsetName)
}
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
delete(ipList, ip.String())
}
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return
}
ipList[ip.String()] = struct{}{}
}
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
ipList, ok := s.ipsets[ipsetName]
if !ok {
return false
}
_, ok = ipList[ip.String()]
return ok
}
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
s.ipsetReference[ipsetName]++
}
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
r, ok := s.ipsetReference[ipsetName]
if !ok {
return
}
if r == 0 {
return
}
s.ipsetReference[ipsetName]--
}
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
if _, ok := s.ipsetReference[ipsetName]; !ok {
return false
}
if s.ipsetReference[ipsetName] == 0 {
return false
}
return true
}

View File

@@ -3,7 +3,6 @@ package nftables
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
@@ -45,18 +44,17 @@ type iFaceMapper interface {
Address() wgaddr.Address
}
// Manager of iptables firewall
// Manager of nftables firewall. Per-family state (peer ACLs, route
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
// by family and provides the public firewall.Manager surface.
type Manager struct {
mutex sync.Mutex
rConn *nftables.Conn
wgIface iFaceMapper
router *router
aclManager *AclManager
// IPv6 counterparts, nil when no v6 overlay
router6 *router
aclManager6 *AclManager
family4 *family
// IPv6 counterpart, nil when no v6 overlay.
family6 *family
notrackOutputChain *nftables.Chain
notrackPreroutingChain *nftables.Chain
@@ -75,14 +73,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface, mtu)
m.family4, err = newFamily(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err)
return nil, fmt.Errorf("create family: %w", err)
}
if wgIface.Address().HasIPv6() {
@@ -100,26 +93,21 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
var err error
m.router6, err = newRouter(workTable6, wgIface, mtu)
m.family6, err = newFamily(workTable6, wgIface, mtu)
if err != nil {
return fmt.Errorf("create v6 router: %w", err)
return fmt.Errorf("create v6 family: %w", err)
}
// Share the same IP forwarding state with the v4 router, since
// EnableIPForwarding controls both v4 and v6 sysctls.
m.router6.ipFwdState = m.router.ipFwdState
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
if err != nil {
return fmt.Errorf("create v6 acl manager: %w", err)
}
m.family6.ipFwdState = m.family4.ipFwdState
return nil
}
// hasIPv6 reports whether the manager has IPv6 components initialized.
func (m *Manager) hasIPv6() bool {
return m.router6 != nil
return m.family6 != nil
}
func (m *Manager) initIPv6() error {
@@ -128,12 +116,8 @@ func (m *Manager) initIPv6() error {
return fmt.Errorf("create v6 work table: %w", err)
}
if err := m.router6.init(workTable6); err != nil {
return fmt.Errorf("v6 router init: %w", err)
}
if err := m.aclManager6.init(workTable6); err != nil {
return fmt.Errorf("v6 acl manager init: %w", err)
if err := m.family6.init(workTable6); err != nil {
return fmt.Errorf("v6 family init: %w", err)
}
return nil
@@ -162,13 +146,13 @@ func (m *Manager) reconcileExternalChains() error {
defer m.mutex.Unlock()
var merr *multierror.Error
if m.router != nil {
if err := m.router.acceptExternalChainsRules(); err != nil {
if m.family4 != nil {
if err := m.family4.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
}
}
if m.hasIPv6() {
if err := m.router6.acceptExternalChainsRules(); err != nil {
if err := m.family6.acceptExternalChainsRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
}
}
@@ -187,12 +171,8 @@ func (m *Manager) initFirewall() (err error) {
}
}()
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
return fmt.Errorf("acl manager init: %w", err)
if err := m.family4.init(workTable); err != nil {
return fmt.Errorf("family init: %w", err)
}
if m.hasIPv6() {
@@ -220,7 +200,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
MTU: m.router.mtu,
MTU: m.family4.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -235,12 +215,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
func (m *Manager) rollbackInit() {
if err := m.router.Reset(); err != nil {
log.Warnf("rollback router: %v", err)
if err := m.family4.Reset(); err != nil {
log.Warnf("rollback family: %v", err)
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
log.Warnf("rollback v6 router: %v", err)
if err := m.family6.Reset(); err != nil {
log.Warnf("rollback v6 family: %v", err)
}
}
if err := m.cleanupNetbirdTables(); err != nil {
@@ -251,118 +231,108 @@ func (m *Manager) rollbackInit() {
}
}
// AddPeerFiltering rule to the firewall
// AddFilterRule installs a packet-filtering rule.
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
if ip.To4() != nil {
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
if !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
}
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
// Destination semantics: zero Network → input chain (peer ACL);
// set Network → forward chain (route ACL).
//
// Sources can mix IPv4 and IPv6 prefixes; they're split by family
// and dispatched to the per-family backends.
func (m *Manager) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
) ([]firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
m.mutex.Lock()
defer m.mutex.Unlock()
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
isRoute := destination.IsPrefix() || destination.IsSet()
if isRoute {
fam := m.family4
if isIPv6RouteRule(sources, destination) {
if !m.hasIPv6() {
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
}
fam = m.family6
}
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
rule, err := fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
return []firewall.Rule{rule}, nil
}
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if m.hasIPv6() && isIPv6Rule(rule) {
return m.aclManager6.DeletePeerRule(rule)
v4Sources, v6Sources := splitSourcesByFamily(sources)
if len(v6Sources) > 0 && !m.hasIPv6() {
return nil, fmt.Errorf("add peer filtering for v6 sources %v: %w", v6Sources, firewall.ErrIPv6NotInitialized)
}
return m.aclManager.DeletePeerRule(rule)
}
func isIPv6Rule(rule firewall.Rule) bool {
r, ok := rule.(*Rule)
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
var out []firewall.Rule
if len(v4Sources) > 0 {
rule, err := m.family4.AddFilterRule(id, v4Sources, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
out = append(out, rule)
}
return len(sources) > 0 && sources[0].Addr().Is6()
if len(v6Sources) > 0 {
rule, err := m.family6.AddFilterRule(id, v6Sources, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
out = append(out, rule)
}
return out, nil
}
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
// router; the cached maps are normally authoritative, so the kernel is only
// consulted when neither map knows about the rule.
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
// DeleteFilterRule removes a filtering rule. The rule is looked up by
// id in each family's filter cache.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
id := rule.ID()
r, err := m.routerForRuleID(id, (*router).hasRule)
if err != nil {
return err
if m.family4.hasRule(id) {
return m.family4.DeleteFilterRule(rule)
}
return r.DeleteRouteRule(rule)
if m.hasIPv6() && m.family6.hasRule(id) {
return m.family6.DeleteFilterRule(rule)
}
log.Debugf("filter rule %s not found in any family", id)
return nil
}
// routerForRuleID picks the router holding the rule with the given id, using
// familyForRuleID picks the family holding the rule with the given id, using
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
// from the kernel once and re-checks before falling back to the v4 router.
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
if has(m.router, id) {
return m.router, nil
// from the kernel once and re-checks before falling back to the v4 family.
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
if has(m.family4, id) {
return m.family4, nil
}
if m.hasIPv6() && has(m.router6, id) {
return m.router6, nil
if m.hasIPv6() && has(m.family6, id) {
return m.family6, nil
}
if !m.hasIPv6() {
return m.router, nil
return m.family4, nil
}
if err := m.router.refreshRulesMap(); err != nil {
if err := m.family4.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v4 rules: %w", err)
}
if err := m.router6.refreshRulesMap(); err != nil {
if err := m.family6.refreshRulesMap(); err != nil {
return nil, fmt.Errorf("refresh v6 rules: %w", err)
}
if has(m.router6, id) && !has(m.router, id) {
return m.router6, nil
if has(m.family6, id) && !has(m.family4, id) {
return m.family6, nil
}
return m.router, nil
return m.family4, nil
}
func (m *Manager) IsServerRouteSupported() bool {
@@ -381,10 +351,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddNatRule(pair)
return m.family6.AddNatRule(pair)
}
if err := m.router.AddNatRule(pair); err != nil {
if err := m.family4.AddNatRule(pair); err != nil {
return err
}
@@ -396,7 +366,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// so the eventual cleanup still works.
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.AddNatRule(v6Pair); err != nil {
if err := m.family6.AddNatRule(v6Pair); err != nil {
return fmt.Errorf("add v6 NAT rule: %w", err)
}
}
@@ -412,18 +382,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if !m.hasIPv6() {
return nil
}
return m.router6.RemoveNatRule(pair)
return m.family6.RemoveNatRule(pair)
}
var merr *multierror.Error
if err := m.router.RemoveNatRule(pair); err != nil {
if err := m.family4.RemoveNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
}
if m.hasIPv6() && pair.Dynamic {
v6Pair := firewall.ToV6NatPair(pair)
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
}
}
@@ -445,11 +415,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.createDefaultAllowRules(); err != nil {
if err := m.family4.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
if m.hasIPv6() {
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
if err := m.family6.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create v6 default allow rules: %w", err)
}
}
@@ -466,11 +436,11 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
return err
}
if m.hasIPv6() {
return firewall.SetLegacyManagement(m.router6, isLegacy)
return firewall.SetLegacyManagement(m.family6, isLegacy)
}
return nil
}
@@ -484,13 +454,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
var merr *multierror.Error
if err := m.router.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
if err := m.family4.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
}
if m.hasIPv6() {
if err := m.router6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
if err := m.family6.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
}
}
@@ -530,14 +500,14 @@ func (m *Manager) SetLogLevel(log.Level) {
}
func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil
}
func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil
@@ -551,12 +521,12 @@ func (m *Manager) Flush() error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.aclManager.Flush(); err != nil {
if err := m.family4.Flush(); err != nil {
return err
}
if m.hasIPv6() {
if err := m.aclManager6.Flush(); err != nil {
if err := m.family6.Flush(); err != nil {
return fmt.Errorf("flush v6 acl: %w", err)
}
}
@@ -577,9 +547,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
if !m.hasIPv6() {
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddDNATRule(rule)
return m.family6.AddDNATRule(rule)
}
return m.router.AddDNATRule(rule)
return m.family4.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
@@ -587,7 +557,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
if err != nil {
return err
}
@@ -608,12 +578,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
}
}
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
return err
}
if m.hasIPv6() && len(v6Prefixes) > 0 {
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
return fmt.Errorf("update v6 set: %w", err)
}
}
@@ -630,9 +600,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
if !m.hasIPv6() {
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveInboundDNAT removes an inbound DNAT rule.
@@ -644,9 +614,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
if !m.hasIPv6() {
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
}
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
@@ -658,9 +628,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
if !m.hasIPv6() {
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
@@ -672,9 +642,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
if !m.hasIPv6() {
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
}
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
}
const (
@@ -903,3 +873,31 @@ func getEstablishedExprs(register uint32) []expr.Any {
},
}
}
// splitSourcesByFamily partitions a mixed-family prefix list into v4
// and v6 buckets. v4-mapped v6 prefixes are normalized to v4 so the
// match builder reads the correct address length. An empty input maps
// to "match any" on v4 only; callers wanting v6 wildcards must include
// ::/0 explicitly.
func splitSourcesByFamily(sources []netip.Prefix) (v4, v6 []netip.Prefix) {
for _, p := range sources {
addr := p.Addr()
if addr.Is4() || addr.Is4In6() {
v4 = append(v4, firewall.UnmapPrefix(p))
} else {
v6 = append(v6, p)
}
}
return v4, v6
}
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
// For static routes, the destination prefix determines the family. For dynamic
// routes (DomainSet), the sources determine the family since management
// duplicates dynamic rules per family.
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
if destination.IsPrefix() {
return destination.Prefix.Addr().Is6()
}
return len(sources) > 0 && sources[0].Addr().Is6()
}

View File

@@ -1,3 +1,5 @@
//go:build integration && !android
package nftables
import (
@@ -70,13 +72,13 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules")
@@ -148,14 +150,14 @@ func TestNftablesManager(t *testing.T) {
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
for _, r := range rule {
err = manager.DeletePeerRule(r)
err = manager.DeleteFilterRule(r)
require.NoError(t, err, "failed to delete rule")
}
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
require.NoError(t, err, "failed to get rules")
// established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
@@ -180,47 +182,39 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
testClient := &nftables.Conn{}
// Add accept rule first
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err, "failed to add accept rule")
// Add deny rule second for the same traffic
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err, "failed to add deny rule")
err = manager.Flush()
require.NoError(t, err, "failed to flush")
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
require.NoError(t, err, "failed to get rules")
t.Logf("Found %d rules in nftables chain", len(rules))
// Find the accept and deny rules and verify deny comes before accept
// Single-source rules emit a direct payload+cmp on the source IP
// (no set lookup). Match by source-IP + port + verdict instead of
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
// that this test predates.
wantSrc := ip.AsSlice()
var acceptRuleIndex, denyRuleIndex = -1, -1
for i, rule := range rules {
hasAcceptHTTPSet := false
hasDenyHTTPSet := false
hasPort80 := false
var hasSrc, hasPort80 bool
var action string
for _, e := range rule.Exprs {
// Check for set lookup
if lookup, ok := e.(*expr.Lookup); ok {
switch lookup.SetName {
case "accept-http":
hasAcceptHTTPSet = true
case "deny-http":
hasDenyHTTPSet = true
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
if bytes.Equal(cmp.Data, wantSrc) {
hasSrc = true
}
}
// Check for port 80
if cmp, ok := e.(*expr.Cmp); ok {
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
hasPort80 = true
}
}
// Check for verdict
if verdict, ok := e.(*expr.Verdict); ok {
switch verdict.Kind {
case expr.VerdictAccept:
@@ -231,11 +225,15 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
}
}
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
if !hasSrc || !hasPort80 {
continue
}
switch action {
case "ACCEPT":
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
acceptRuleIndex = i
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
case "DROP":
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
denyRuleIndex = i
}
}
@@ -279,7 +277,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -361,10 +359,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
@@ -437,10 +435,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
})
ip := netip.MustParseAddr("fd00::2")
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err, "add v6 peer filtering rule")
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
@@ -550,7 +548,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
}
}
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
prefixes,
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
@@ -591,7 +589,7 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
verifyIptablesOutput(t, stdout, stderr)
})
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{},
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},

View File

@@ -0,0 +1,492 @@
//go:build !android
package nftables
import (
"fmt"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/client/net"
)
func (r *family) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
r.rollbackRules(pair)
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
}
return nil
}
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
func (r *family) rollbackRules(pair firewall.RouterPair) {
keys := []firewall.RuleID{
firewall.GenKey(firewall.ForwardingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, pair),
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
}
for _, key := range keys {
rule, ok := r.rules[key]
if !ok {
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("rollback set counter for %s: %v", key, err)
}
delete(r.rules, key)
}
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *family) addNatRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
// Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
func (r *family) addPostroutingRules() {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
func (r *family) addMSSClampingRules() error {
overhead := uint16(ipv4TCPHeaderSize)
if r.af.tableFamily == nftables.TableFamilyIPv6 {
overhead = ipv6TCPHeaderSize
}
if r.mtu <= overhead {
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
return nil
}
mss := r.mtu - overhead
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return r.conn.Flush()
}
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
return nil
}
if rule.Handle == 0 {
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *family) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *family) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *family) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: r.af.tableFamily,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
}
}
if err := r.removeLegacyRouteRule(pair); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
}
// Set counters are decremented in the sub-methods above before flush. If flush fails,
// counters will be off until the next successful removal or refresh cycle.
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *family) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
rule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("prerouting rule %s not found", ruleKey)
return nil
}
if rule.Handle == 0 {
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
return nil
}
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
}
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build integration && !android
package nftables
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
// need fw manager to init both acl mgr and family for all chains to be present
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
nftablesTestingClient := &nftables.Conn{}
rtr := manager.router
rtr := manager.family4
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
@@ -90,7 +90,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
}
// Build CIDR matching expressions
testRouter := &router{af: afIPv4}
testRouter := &family{af: afIPv4}
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
@@ -107,7 +107,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
// Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
@@ -135,9 +135,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.router
rtr := manager.family4
// First add the NAT rule using the router's method
// First add the NAT rule using the family's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
@@ -147,7 +147,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
found = true
break
}
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
found = true
break
}
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
require.NoError(t, r.init(workTable))
defer func(r *router) {
defer func(r *family) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddFilterRule failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map")
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
require.True(t, ok, "Rule not found in filters map")
rule := stored.nftRule
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.ID() {
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
nftRule = rule
break
}
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
require.NoError(t, r.Reset(), "Failed to reset family")
}()
tests := []struct {
@@ -518,11 +518,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
require.NoError(t, err, "Failed to create v6 work table")
defer deleteWorkTableIPv6()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create family")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
require.NoError(t, r.Reset(), "Failed to reset family")
}()
tests := []struct {
@@ -861,13 +861,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Add a real rule to the kernel
ruleKey, err := r.AddRouteFiltering(
ruleKey, err := r.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
@@ -878,11 +878,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey))
require.NoError(t, r.DeleteFilterRule(ruleKey))
})
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
staleKey := "stale-rule-that-does-not-exist"
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
@@ -902,6 +902,55 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
}
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
// rule is actually removed from the kernel on delete. The route chain is
// not refreshed by Flush, so the stored rule carries a zero handle;
// DeleteFilterRule must pull live handles itself before issuing the
// delete or the kernel rule leaks. Regression test for that path.
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err)
defer deleteWorkTable()
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
ruleKey, err := r.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
firewall.ProtocolTCP,
nil,
&firewall.Port{Values: []uint16{80}},
firewall.ActionAccept,
)
require.NoError(t, err)
countKernelRules := func() int {
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err)
n := 0
for _, rule := range list {
if string(rule.UserData) == string(ruleKey.ID()) {
n++
}
}
return n
}
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
require.NoError(t, r.DeleteFilterRule(ruleKey))
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
}
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
@@ -911,24 +960,27 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
require.NoError(t, err)
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, r.init(workTable))
defer func() { require.NoError(t, r.Reset()) }()
// Inject a stale entry with Handle=0
staleKey := "stale-route-rule"
r.rules[staleKey] = &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
staleKey := id.RuleID("stale-route-rule")
r.filters[staleKey] = &Rule{
nftRule: &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Handle: 0,
UserData: []byte(staleKey),
},
ruleID: staleKey,
}
// DeleteRouteRule should not return an error for stale handles
err = r.DeleteRouteRule(id.RuleID(staleKey))
// DeleteFilterRule should not return an error for stale handles
err = r.DeleteFilterRule(staleKey)
assert.NoError(t, err, "deleting a stale rule should not error")
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
}
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
@@ -950,7 +1002,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
Masquerade: true,
}
rtr := manager.router
rtr := manager.family4
// First add succeeds
err = rtr.AddNatRule(pair)
@@ -979,7 +1031,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
found := 0
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
found++
}
}
@@ -1010,7 +1062,7 @@ func TestCalculateLastIP(t *testing.T) {
}
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
r := &router{af: afIPv6}
r := &family{af: afIPv6}
prefixes := []netip.Prefix{
netip.MustParsePrefix("fd00::/64"),
netip.MustParsePrefix("2001:db8::1/128"),

View File

@@ -1,21 +1,26 @@
package nftables
import (
"net"
"net/netip"
"github.com/google/nftables"
"github.com/netbirdio/netbird/client/firewall/manager"
)
// Rule to handle management of rules
// Rule wraps an installed filter rule (peer or route). Source set
// membership is encoded in the rule's expressions; DeleteFilterRule
// recovers the set name via findSets so the refcounter can drop the
// right reference. mangleRule is set only for peer rules.
type Rule struct {
nftRule *nftables.Rule
mangleRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
// sources is the canonical source list this rule was created for.
sources []netip.Prefix
id manager.RuleID
}
// GetRuleID returns the rule id
func (r *Rule) ID() string {
return r.ruleID
// ID returns the rule id
func (r *Rule) ID() manager.RuleID {
return r.id
}

View File

@@ -0,0 +1,23 @@
//go:build integration && !android
package nftables
import (
"net"
"net/netip"
)
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, _ := netip.AddrFromSlice(ip)
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"os"
"slices"
@@ -72,8 +71,10 @@ const (
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule
// PeerRules is the canonical list-based storage for peer ACL rules.
// Match order is significant: drop rules come before accept rules so
// callers should consult the slice in order.
type PeerRules []*PeerRule
type RouteRules []*RouteRule
@@ -86,20 +87,22 @@ func (r RouteRules) Sort() {
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
return 1
}
return strings.Compare(a.id, b.id)
return strings.Compare(string(a.id), string(b.id))
})
}
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[netip.Addr]RuleSet
incomingDenyRules map[netip.Addr]RuleSet
incomingRules map[netip.Addr]RuleSet
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
incomingDenyRules PeerRules
incomingRules PeerRules
incomingDenyIndex peerRuleIndex
incomingAcceptIndex peerRuleIndex
peerRulesMap map[nbid.RuleID]*PeerRule
routeRules RouteRules
routeRulesMap map[nbid.RuleID]*RouteRule
decoders sync.Pool
wgIface common.IFaceMapper
nativeFirewall firewall.Manager
mutex sync.RWMutex
@@ -255,9 +258,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[netip.Addr]RuleSet),
incomingDenyRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface,
localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes,
@@ -266,6 +266,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
flowLogger: flowLogger,
netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding,
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
@@ -488,75 +489,284 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
return nil
}
// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
func (m *Manager) AddPeerFiltering(
// addPeerFiltering installs an input-chain rule that matches packets
// by source only. Called from AddFilterRule when the caller doesn't
// specify a destination. Mixed-family inputs are split: each family
// gets its own rule with a family-correct ipLayer so packet decoding
// matches what the matcher expects.
func (m *Manager) addPeerFiltering(
id []byte,
ip net.IP,
sources []netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
_ string,
) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{
id: uuid.New().String(),
mgmtId: id,
ip: i,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
drop: action == firewall.ActionDrop,
}
if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
r.matchByIP = false
}
r.sPort = sPort
r.dPort = dPort
r.protoLayer = protoToLayer(proto, r.ipLayer)
m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet
if r.drop {
targetMap = m.incomingDenyRules
} else {
targetMap = m.incomingRules
defer m.mutex.Unlock()
if sourcesMatchAny(sources) {
return []firewall.Rule{m.addOnePeerRule(id, sources, layerTypeAll, true, proto, sPort, dPort, action)}, nil
}
if _, ok := targetMap[r.ip]; !ok {
targetMap[r.ip] = make(RuleSet)
v4, v6 := splitPrefixesByFamily(sources)
out := make([]firewall.Rule, 0, 2)
if len(v4) > 0 {
out = append(out, m.addOnePeerRule(id, v4, layers.LayerTypeIPv4, false, proto, sPort, dPort, action))
}
targetMap[r.ip][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
if len(v6) > 0 {
out = append(out, m.addOnePeerRule(id, v6, layers.LayerTypeIPv6, false, proto, sPort, dPort, action))
}
return out, nil
}
func (m *Manager) AddRouteFiltering(
// addOnePeerRule builds and registers a single-family peer rule, or
// returns the existing rule when one with the same content key is
// already installed. The caller must hold m.mutex. The content key is
// the shared GenerateRuleID with an empty destination, so peer
// rules dedup the same way route rules and the kernel backends do.
//
// There is no refcount: a content key is installed once and deleted on
// the first DeleteFilterRule for that key. The caller must therefore
// key its own tracking by the returned rule id so add and delete stay
// balanced per content key; the acl manager does this via
// peerRulesPairs. The content key is order-independent, so callers
// passing the same sources in any order dedup to one rule.
func (m *Manager) addOnePeerRule(
id []byte,
sources []netip.Prefix,
ipLayer gopacket.LayerType,
matchAny bool,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) *PeerRule {
ruleKey := nbid.GenerateRuleID(sources, firewall.Network{}, proto, sPort, dPort, action)
if existing, ok := m.peerRulesMap[ruleKey]; ok {
return existing
}
rule := m.buildPeerRule(ruleKey, id, sources, ipLayer, matchAny, proto, sPort, dPort, action)
m.registerPeerRule(rule)
return rule
}
func (m *Manager) buildPeerRule(
ruleKey nbid.RuleID,
id []byte,
sources []netip.Prefix,
ipLayer gopacket.LayerType,
matchAny bool,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) *PeerRule {
r := &PeerRule{
id: ruleKey,
mgmtId: id,
sources: sources,
matchAny: matchAny,
action: action,
srcPort: sPort,
dstPort: dPort,
}
if !matchAny {
r.sourceAddrs = make(map[netip.Addr]struct{}, len(sources))
for _, p := range sources {
if p.Bits() == p.Addr().BitLen() {
r.sourceAddrs[p.Addr()] = struct{}{}
}
}
}
r.protoLayer = protoToLayer(proto, ipLayer)
return r
}
// registerPeerRule records a freshly built peer rule in the matching
// slice, index, and dedup map. The caller must hold m.mutex.
func (m *Manager) registerPeerRule(r *PeerRule) {
if r.action == firewall.ActionDrop {
m.incomingDenyRules = append(m.incomingDenyRules, r)
m.incomingDenyIndex.add(r)
} else {
m.incomingRules = append(m.incomingRules, r)
m.incomingAcceptIndex.add(r)
}
m.peerRulesMap[r.id] = r
}
// splitPrefixesByFamily partitions a mixed-family prefix list into v4
// and v6 buckets. v4-mapped v6 addresses are normalized to v4.
func splitPrefixesByFamily(sources []netip.Prefix) (v4, v6 []netip.Prefix) {
for _, p := range sources {
addr := p.Addr()
if addr.Is4() || addr.Is4In6() {
v4 = append(v4, firewall.UnmapPrefix(p))
} else {
v6 = append(v6, p)
}
}
return v4, v6
}
// peerRuleIndex is the source-side dispatcher consulted on the packet
// hot path. It separates rules into three disjoint buckets based on
// the shape of their source list. Every rule lives in exactly one:
//
// - bySource: every source is a host prefix (full-bit-length for
// its family: /32 for v4, /128 for v6). The map key is the
// concrete source address, so a hit guarantees the rule's source
// filter passes; the matcher proceeds straight to proto/port
// checks without re-verifying the source.
// - byCIDR: any source list containing a prefix coarser than a
// single host. The matcher walks this slice and runs prefix
// Contains() per rule. Expected to be empty for typical peer
// ACLs (always host prefixes) and small even when populated.
// - matchAny: at least one /0 source. Always matches every packet;
// no per-rule source check needed.
//
// Maintained incrementally by Add/DeleteFilterRule, never rebuilt.
type peerRuleIndex struct {
bySource map[netip.Addr][]*PeerRule
byCIDR []*PeerRule
matchAny []*PeerRule
}
func (i *peerRuleIndex) add(r *PeerRule) {
switch {
case r.matchAny:
i.matchAny = append(i.matchAny, r)
case hasNonHostSource(r):
i.byCIDR = append(i.byCIDR, r)
default:
if i.bySource == nil {
i.bySource = make(map[netip.Addr][]*PeerRule)
}
for a := range r.sourceAddrs {
i.bySource[a] = append(i.bySource[a], r)
}
}
}
func (i *peerRuleIndex) remove(r *PeerRule) {
switch {
case r.matchAny:
i.matchAny = slices.DeleteFunc(i.matchAny, eqRule(r))
case hasNonHostSource(r):
i.byCIDR = slices.DeleteFunc(i.byCIDR, eqRule(r))
default:
if i.bySource == nil {
return
}
for a := range r.sourceAddrs {
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
if len(entries) == 0 {
delete(i.bySource, a)
} else {
i.bySource[a] = entries
}
}
}
}
func (i *peerRuleIndex) reset() {
i.bySource = nil
i.byCIDR = i.byCIDR[:0]
i.matchAny = i.matchAny[:0]
}
func eqRule(target *PeerRule) func(*PeerRule) bool {
return func(p *PeerRule) bool { return p == target }
}
// hasNonHostSource reports whether the rule has any source prefix
// that is not a single host address. Called only at add/remove time,
// not on the packet path.
func hasNonHostSource(r *PeerRule) bool {
for _, p := range r.sources {
if p.Bits() != p.Addr().BitLen() {
return true
}
}
return false
}
// sourcesMatchAny reports whether the source list matches every source,
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
// the deliberate /0 case.
func sourcesMatchAny(sources []netip.Prefix) bool {
for _, p := range sources {
if p.Bits() == 0 {
return true
}
}
return false
}
// AddFilterRule is the unified entry point for both peer (input chain)
// and route (forward chain) filtering rules. The destination
// distinguishes the two semantics: a zero Network installs an
// input-side rule that matches by source only; a set Network installs
// a forward-side rule that also matches the destination.
func (m *Manager) AddFilterRule(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
) ([]firewall.Rule, error) {
if len(sources) == 0 {
return nil, firewall.ErrNoSources
}
if destination.IsPrefix() || destination.IsSet() {
m.mutex.Lock()
defer m.mutex.Unlock()
r, err := m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
return []firewall.Rule{r}, nil
}
// Peer path: sources are expected to be single-family; the acl
// manager keys its selector grouping on family, and management
// emits one FirewallRule per family upstream of that. The kernel
// backends still split defensively because their per-family
// tables can't encode the other family's addresses; uspfilter
// has no such constraint.
return m.addPeerFiltering(id, sources, proto, sPort, dPort, action)
}
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
// is used to route to the correct internal path.
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
switch r := rule.(type) {
case *RouteRule:
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
case *PeerRule:
return m.deletePeerRule(rule)
case firewall.RuleID:
// Deletion by bare id. Resolve to the concrete rule rather than
// assuming a route: a peer rule if we own that id, otherwise the
// route path. Without this, deleting a peer rule by id would
// silently miss in the route map.
m.mutex.Lock()
defer m.mutex.Unlock()
if pr, ok := m.peerRulesMap[r]; ok {
return m.deletePeerRuleLocked(pr)
}
return m.deleteRouteRule(rule)
default:
// Native firewall route rules implement firewall.Rule but
// aren't one of our concrete types; the route path knows how
// to forward them.
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
}
func (m *Manager) addRouteFiltering(
@@ -568,18 +778,24 @@ func (m *Manager) addRouteFiltering(
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
rules, err := m.nativeFirewall.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err != nil {
return nil, err
}
if len(rules) == 0 {
return nil, fmt.Errorf("native firewall returned no rule")
}
return rules[0], nil
}
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
ruleKey := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
return existingRule, nil
}
rule := RouteRule{
// TODO: consolidate these IDs
id: string(ruleKey),
id: ruleKey,
mgmtId: id,
sources: sources,
dstSet: destination.Set,
@@ -599,25 +815,18 @@ func (m *Manager) addRouteFiltering(
return &rule, nil
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
return m.nativeFirewall.DeleteFilterRule(rule)
}
ruleKey := nbid.RuleID(rule.ID())
ruleKey := rule.ID()
if _, ok := m.routeRulesMap[ruleKey]; !ok {
return fmt.Errorf("route rule not found: %s", ruleKey)
}
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == string(ruleKey)
return r.id == ruleKey
})
if idx < 0 {
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
@@ -628,8 +837,8 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
return nil
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// deletePeerRule removes an input-chain rule. Acquires m.mutex.
func (m *Manager) deletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -637,26 +846,31 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
if !ok {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
return m.deletePeerRuleLocked(r)
}
var sourceMap map[netip.Addr]RuleSet
if r.drop {
sourceMap = m.incomingDenyRules
// deletePeerRuleLocked removes a peer rule from the matching slice,
// index, and dedup map. The caller must hold m.mutex.
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
var target *PeerRules
if r.action == firewall.ActionDrop {
target = &m.incomingDenyRules
} else {
sourceMap = m.incomingRules
target = &m.incomingRules
}
if ruleset, ok := sourceMap[r.ip]; ok {
if _, exists := ruleset[r.id]; !exists {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(ruleset, r.id)
if len(ruleset) == 0 {
delete(sourceMap, r.ip)
}
} else {
pos := slices.IndexFunc(*target, func(p *PeerRule) bool { return p.id == r.id })
if pos < 0 {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
stored := (*target)[pos]
*target = slices.Delete(*target, pos, pos+1)
if r.action == firewall.ActionDrop {
m.incomingDenyIndex.remove(stored)
} else {
m.incomingAcceptIndex.remove(stored)
}
delete(m.peerRulesMap, r.id)
return nil
}
@@ -674,9 +888,11 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held.
func (m *Manager) resetState() {
clear(m.outgoingRules)
clear(m.incomingDenyRules)
clear(m.incomingRules)
m.incomingDenyRules = m.incomingDenyRules[:0]
m.incomingRules = m.incomingRules[:0]
m.incomingDenyIndex.reset()
m.incomingAcceptIndex.reset()
clear(m.peerRulesMap)
clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
@@ -820,11 +1036,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
case layers.LayerTypeIPv4:
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
return src.Unmap(), dst.Unmap()
case layers.LayerTypeIPv6:
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
return src.Unmap(), dst.Unmap()
default:
return netip.Addr{}, netip.Addr{}
}
@@ -1404,23 +1620,83 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
return nil, false
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter
}
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter
}
return nil, true
}
// match walks the three buckets in order and returns the first rule
// that matches src and the decoded packet. Source filtering is
// dispatched up-front by bucket: bySource[src] is by definition
// source-matching; byCIDR rules need a prefix Contains() check;
// matchAny rules apply to every source. Within each bucket the
// matcher runs proto/port filters.
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
if len(i.bySource) > 0 {
if rules := i.bySource[src]; len(rules) > 0 {
for _, rule := range rules {
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
}
}
for _, rule := range i.byCIDR {
if !prefixesContain(rule.sources, src) {
continue
}
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
for _, rule := range i.matchAny {
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
return id, drop, true
}
}
return nil, false, false
}
// matchProto applies the proto/port half of a rule against the
// decoded packet. Source matching is the caller's responsibility.
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
drop := rule.action == firewall.ActionDrop
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
return nil, false, false
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, drop, true
}
return nil, false, false
}
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
for _, p := range sources {
if p.Contains(src) {
return true
}
}
return false
}
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
if rulePort == nil {
return true
@@ -1438,39 +1714,6 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false
}
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue
}
if rule.protoLayer == layerTypeAll {
return rule.mgmtId, rule.drop, true
}
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
continue
}
switch payloadLayer {
case layers.LayerTypeTCP:
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.mgmtId, rule.drop, true
}
}
return nil, false, false
}
// routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock()

View File

@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,15 +114,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(
_, err := m.AddFilterRule(
nil,
ip,
pfx(ip), fw.Network{},
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
"",
)
fw.ActionAccept)
require.NoError(b, err)
}
},
@@ -133,15 +131,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(
_, err := m.AddFilterRule(
nil,
net.ParseIP("0.0.0.0"),
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
fw.ActionDrop)
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
@@ -546,7 +542,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -629,7 +625,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -739,7 +735,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -818,7 +814,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
})
if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
require.NoError(b, err)
}
@@ -931,7 +927,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, r := range rules {
dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}

View File

@@ -496,39 +496,35 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule for the same IP to test drop rule precedence
rules, err := manager.AddPeerFiltering(
rules, err := manager.AddFilterRule(
nil,
net.ParseIP(tc.ruleIP),
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
fw.ActionAccept)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(
rules, err := manager.AddFilterRule(
nil,
net.ParseIP(tc.ruleIP),
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
tc.ruleProto,
tc.ruleSrcPort,
tc.ruleDstPort,
tc.ruleAction,
"",
)
tc.ruleAction)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
}
})
@@ -672,21 +668,21 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
require.NoError(t, err)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
}
})
@@ -1405,7 +1401,7 @@ func TestRouteACLFiltering(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
@@ -1415,13 +1411,13 @@ func TestRouteACLFiltering(t *testing.T) {
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
require.NotEmpty(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule[0]))
})
}
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
tc.rule.sources,
tc.rule.dest,
@@ -1431,10 +1427,10 @@ func TestRouteACLFiltering(t *testing.T) {
tc.rule.action,
)
require.NoError(t, err)
require.NotNil(t, rule)
require.NotEmpty(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule[0]))
})
srcIP := netip.MustParseAddr(tc.srcIP)
@@ -1604,7 +1600,7 @@ func TestRouteACLOrder(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var rules []fw.Rule
for _, r := range tc.rules {
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
r.sources,
r.dest,
@@ -1614,13 +1610,13 @@ func TestRouteACLOrder(t *testing.T) {
r.action,
)
require.NoError(t, err)
require.NotNil(t, rule)
rules = append(rules, rule)
require.NotEmpty(t, rule)
rules = append(rules, rule...)
}
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeleteRouteRule(rule))
require.NoError(t, manager.DeleteFilterRule(rule))
}
})
@@ -1655,7 +1651,7 @@ func TestRouteACLSet(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -1689,7 +1685,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
manager := setupRoutedManager(t, "10.10.0.100/16")
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
_, err := manager.AddRouteFiltering(
_, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: v6Dst},
@@ -1700,7 +1696,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
)
require.NoError(t, err)
_, err = manager.AddRouteFiltering(
_, err = manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},

View File

@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add rule first time
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
require.NotNil(t, rule1)
// Add the same rule again
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -55,7 +55,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
require.NotNil(t, rule2)
// These should be the same (idempotent) like nftables/iptables implementations
assert.Equal(t, rule1.ID(), rule2.ID(),
assert.Equal(t, rule1[0].ID(), rule2[0].ID(),
"Adding the same rule twice should return the same rule ID (idempotent)")
manager.mutex.RLock()
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
// Add first rule
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
require.NoError(t, err)
// Add different rule (different destination)
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-2"),
sources,
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
@@ -97,7 +97,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
)
require.NoError(t, err)
assert.NotEqual(t, rule1.ID(), rule2.ID(),
assert.NotEqual(t, rule1[0].ID(), rule2[0].ID(),
"Different rules should have different IDs")
manager.mutex.RLock()
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
require.True(t, pass, "Traffic should pass with rule in place")
// Re-add same rule (simulates network map update)
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -143,11 +143,11 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
)
require.NoError(t, err)
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
// Idempotent IDs mean rule1[0].ID() == rule2[0].ID(), so the ACL manager
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
// would remove the only matching rule and cause a traffic gap.
if rule1.ID() != rule2.ID() {
err = manager.DeleteRouteRule(rule1)
if rule1[0].ID() != rule2[0].ID() {
err = manager.DeleteFilterRule(rule1[0])
require.NoError(t, err)
}
@@ -274,7 +274,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
// Simulate 5 network map updates with the same route rule
for i := 0; i < 5; i++ {
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -304,7 +304,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
// Add same rule twice
rule1, err := manager.AddRouteFiltering(
rule1, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -315,7 +315,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
)
require.NoError(t, err)
rule2, err := manager.AddRouteFiltering(
rule2, err := manager.AddFilterRule(
[]byte("policy-1"),
sources,
destination,
@@ -326,10 +326,10 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
)
require.NoError(t, err)
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
require.Equal(t, rule1[0].ID(), rule2[0].ID(), "Should return same rule ID")
// Delete using first reference
err = manager.DeleteRouteRule(rule1)
err = manager.DeleteFilterRule(rule1[0])
require.NoError(t, err)
// Verify traffic no longer passes

View File

@@ -89,7 +89,7 @@ func TestManagerCreate(t *testing.T) {
}
}
func TestManagerAddPeerFiltering(t *testing.T) {
func TestManagerAddFilterRule(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error {
@@ -109,7 +109,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -142,7 +142,7 @@ func TestManagerDeleteRule(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -155,42 +155,39 @@ func TestManagerDeleteRule(t *testing.T) {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule exists in deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
if peerRule.action == fw.ActionDrop {
found = findRuleByID(m.incomingDenyRules, ip, r.ID())
} else {
_, found = m.incomingRules[ip][r.ID()]
found = findRuleByID(m.incomingRules, ip, r.ID())
}
if !found {
t.Errorf("rule2 is not in the expected rules map")
t.Errorf("rule2 is not in the expected rules list")
}
}
for _, r := range rule2 {
err = m.DeletePeerRule(r)
err = m.DeleteFilterRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
// Check rules are removed from appropriate maps
for _, r := range rule2 {
peerRule, ok := r.(*PeerRule)
if !ok {
t.Errorf("rule should be a PeerRule")
continue
}
// Check if rule is removed from deny or allow maps based on action
var found bool
if peerRule.drop {
_, found = m.incomingDenyRules[ip][r.ID()]
if peerRule.action == fw.ActionDrop {
found = findRuleByID(m.incomingDenyRules, ip, r.ID())
} else {
_, found = m.incomingRules[ip][r.ID()]
found = findRuleByID(m.incomingRules, ip, r.ID())
}
if found {
t.Errorf("rule2 should be removed from the rules map")
t.Errorf("rule2 should be removed from the rules list")
}
}
}
@@ -260,36 +257,34 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
addr := netip.MustParseAddr("192.168.1.1")
// Add multiple deny rules for different ports
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
require.NoError(t, err)
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
// Delete the first deny rule
err = m.DeletePeerRule(rule1[0])
err = m.DeleteFilterRule(rule1[0])
require.NoError(t, err)
m.mutex.RLock()
denyCount = len(m.incomingDenyRules[addr])
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
// Delete the second deny rule
err = m.DeletePeerRule(rule2[0])
err = m.DeleteFilterRule(rule2[0])
require.NoError(t, err)
m.mutex.RLock()
_, exists := m.incomingDenyRules[addr]
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
m.mutex.RUnlock()
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
require.False(t, exists, "Deny rules should be cleaned up when empty")
}
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
@@ -311,27 +306,25 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
// Simulate 10 network map updates: add rule, delete old, add new
for i := 0; i < 10; i++ {
// Add a deny rule
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
require.NoError(t, err)
// Add an allow rule
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
// Delete them (simulating ACL manager cleanup)
for _, r := range rules {
require.NoError(t, m.DeletePeerRule(r))
require.NoError(t, m.DeleteFilterRule(r))
}
for _, r := range allowRules {
require.NoError(t, m.DeletePeerRule(r))
require.NoError(t, m.DeleteFilterRule(r))
}
}
m.mutex.RLock()
denyCount := len(m.incomingDenyRules[addr])
allowCount := len(m.incomingRules[addr])
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
allowCount := countRulesForAddr(m.incomingRules, addr)
m.mutex.RUnlock()
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
@@ -354,41 +347,39 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
// Add allow rule for port 80
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
// Add deny rule for port 22
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
require.NoError(t, err)
addr := netip.MustParseAddr("192.168.1.1")
m.mutex.RLock()
allowCount := len(m.incomingRules[addr])
denyCount := len(m.incomingDenyRules[addr])
allowCount := countRulesForAddr(m.incomingRules, addr)
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
// Delete allow rule should not affect deny rule
err = m.DeletePeerRule(allowRule[0])
err = m.DeleteFilterRule(allowRule[0])
require.NoError(t, err)
m.mutex.RLock()
denyCountAfter := len(m.incomingDenyRules[addr])
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
m.mutex.RUnlock()
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
// Delete deny rule
err = m.DeletePeerRule(denyRule[0])
err = m.DeleteFilterRule(denyRule[0])
require.NoError(t, err)
m.mutex.RLock()
_, denyExists := m.incomingDenyRules[addr]
_, allowExists := m.incomingRules[addr]
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
allowExists := countRulesForAddr(m.incomingRules, addr) > 0
m.mutex.RUnlock()
require.False(t, denyExists, "Deny rules should be empty")
@@ -411,7 +402,7 @@ func TestManagerReset(t *testing.T) {
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -423,7 +414,7 @@ func TestManagerReset(t *testing.T) {
return
}
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
if len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
t.Errorf("rules are not empty")
}
}
@@ -449,7 +440,7 @@ func TestNotMatchByIP(t *testing.T) {
proto := fw.ProtocolUDP
action := fw.ActionAccept
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -621,7 +612,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
require.NoError(t, err, "failed to add rule")
}
@@ -858,7 +849,7 @@ func TestUpdateSetMerge(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -916,7 +907,7 @@ func TestUpdateSetMerge(t *testing.T) {
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
if r.id == rule[0].ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
@@ -939,7 +930,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
rule, err := manager.AddFilterRule(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
@@ -965,7 +956,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
if r.id == rule[0].ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
@@ -998,7 +989,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
if r.id == rule[0].ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")

View File

@@ -0,0 +1,327 @@
//go:build uspbench
package uspfilter
import (
"fmt"
"io"
"math/rand"
"net"
"net/netip"
"runtime"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
// rules, each with K source peers in its set.
//
// With the reverse-source index, miss cost is independent of M and
// hit cost grows only with the number of rules touching a single
// srcIP, not with total rule count.
func BenchmarkPeerACLMatch(b *testing.B) {
shapes := []struct{ M, K int }{
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
}
families := []struct {
name string
v6 bool
}{{"v4", false}, {"v6", true}}
for _, fam := range families {
for _, s := range shapes {
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, true, fam.v6)
})
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
runPeerACLBench(b, s.M, s.K, false, fam.v6)
})
}
}
}
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
// Miss packets are dropped, so they always traverse the full peer
// ACL matcher (every bucket) without short-circuiting and without
// touching conntrack. Disable conntrack for the miss case so it
// measures the matcher, not established-state lookups. The hit case
// keeps conntrack on: an accepted packet reaches trackInbound, which
// needs the trackers conntrack creates.
if !hit {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
}
bits := 32
genPkt := generatePacket
addrs := uniqueAddrs
if v6 {
bits = 128
genPkt = generatePacket6
addrs = uniqueAddrs6
}
// dstIP must be a local IP so filterInbound takes the local-traffic
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
// address the manager doesn't own would be treated as routed and
// short-circuit before the peer matcher.
dstIP := addrs(1, 2)[0]
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
if v6 {
// The local-IP manager needs a valid v4 address too; expose the v6
// dst as the interface's IPv6 so IsLocalIP recognizes it.
mockAddr = wgaddr.Address{
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/16"),
IPv6: dstIP,
IPv6Net: netip.PrefixFrom(dstIP, bits),
}
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { return mockAddr },
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
// Generate M policies × K source peers, all distinct.
all := addrs(m*k, 1)
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, k)
for j, a := range all[i*k : (i+1)*k] {
sources[j] = netip.PrefixFrom(a, bits)
}
_, err := manager.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
// Hit: cycle through real sources, picking the matching policy's port.
// Miss: a source from a disjoint range, port 80 (matches no policy).
var pktFn func(i int) []byte
if hit {
pktFn = func(i int) []byte {
policy := i % m
src := all[policy*k+(i%k)]
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
}
} else {
miss := addrs(4096, 99)
pktFn = func(i int) []byte {
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
}
}
// Pre-build a pool to avoid allocations dominating the measurement.
pool := make([][]byte, 1024)
for i := range pool {
pool[i] = pktFn(i)
}
// Confirm the matcher is actually exercised: a hit packet must be
// allowed and a miss packet dropped. Without this the benchmark
// could silently time the routed early-return instead.
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
"benchmark must reach the peer ACL matcher")
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterInbound(pool[i%len(pool)], 0)
}
}
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
// the source-keyed index across realistic deployment shapes. Two
// dimensions matter: (M, K), the number of policies × peers-per-policy,
// and overlap, the fraction of peers shared between policies.
//
// The output uses ReportMetric("bytes/rule") so the cost can be
// compared across shapes directly. Total bytes = bytes/rule * M.
func BenchmarkPeerACLIndexMemory(b *testing.B) {
cases := []struct {
name string
M, K int
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
}{
{"M=10/K=100/disjoint", 10, 100, 0},
{"M=100/K=100/disjoint", 100, 100, 0},
{"M=100/K=1000/disjoint", 100, 1000, 0},
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
}
for _, c := range cases {
b.Run(c.name, func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
mgr, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
runtime.GC()
var ms runtime.MemStats
runtime.ReadMemStats(&ms)
before := ms.HeapAlloc
// Drop the manager's external roots so we can isolate
// the index cost. We hold the manager itself live; the
// index is what we measure on the second pass.
mgr.incomingAcceptIndex.reset()
mgr.incomingDenyIndex.reset()
mgr.incomingRules = mgr.incomingRules[:0]
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
runtime.GC()
runtime.ReadMemStats(&ms)
after := ms.HeapAlloc
delta := int64(before) - int64(after)
if delta < 0 {
delta = 0
}
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
b.ReportMetric(float64(delta), "bytes/total")
require.NoError(b, mgr.Close(nil))
}
})
}
}
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
b.Helper()
pool := uniqueAddrs(k+m*k, 1) // big enough universe
sharedLen := int(float64(k) * overlapFrac)
if sharedLen > k {
sharedLen = k
}
shared := pool[:sharedLen]
uniquePool := pool[sharedLen:]
for i := 0; i < m; i++ {
sources := make([]netip.Prefix, 0, k)
for _, a := range shared {
sources = append(sources, netip.PrefixFrom(a, 32))
}
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
for _, a := range unique {
sources = append(sources, netip.PrefixFrom(a, 32))
}
_, err := mgr.AddFilterRule(
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
&fw.Port{Values: []uint16{uint16(80 + i)}},
fw.ActionAccept)
require.NoError(b, err)
}
}
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
// policy sources / dst; seed 99 puts misses in 10/8.
func uniqueAddrs(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [4]byte
if miss {
b[0] = 10
b[1] = byte(r.Intn(256))
} else {
b[0] = 100
b[1] = byte(64 + r.Intn(63))
}
b[2] = byte(r.Intn(256))
b[3] = byte(1 + r.Intn(254))
a := netip.AddrFrom4(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
// disjoint from any source.
func uniqueAddrs6(n int, seed int64) []netip.Addr {
out := make([]netip.Addr, 0, n)
seen := make(map[netip.Addr]struct{}, n)
r := rand.New(rand.NewSource(seed))
miss := seed == 99
for len(out) < n {
var b [16]byte
if miss {
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
} else {
b[0] = 0xfd
}
for x := 8; x < 16; x++ {
b[x] = byte(r.Intn(256))
}
a := netip.AddrFrom16(b)
if _, ok := seen[a]; ok {
continue
}
seen[a] = struct{}{}
out = append(out, a)
}
return out
}
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
// generatePacket for the v4 case.
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
b.Helper()
ipv6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: protocol,
SrcIP: srcIP,
DstIP: dstIP,
}
var transportLayer gopacket.SerializableLayer
switch protocol {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
transportLayer = udp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
return buf.Bytes()
}

View File

@@ -0,0 +1,175 @@
package uspfilter
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
func newTestManager(t *testing.T) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err, "create manager")
return m
}
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
// the same peer rule twice does not create two backing rules. The acl
// manager keys its own cache, but the firewall backend must be
// idempotent on its own so a double-apply cannot leak rules, matching
// the route path and the kernel backends.
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
require.Len(t, first, 1, "first add should yield one rule")
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
require.Len(t, second, 1, "second add should yield one rule")
assert.Equal(t, first[0].ID(), second[0].ID(), "duplicate add should return the same rule id")
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
}
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
// backend's no-refcount contract: a content key installed twice is
// still one rule, and the first DeleteFilterRule removes it. The
// backend does not refcount, so balance is the caller's job (it keys
// its tracking by the returned id and deletes once per key). If this
// ever silently grew a refcount, the acl manager's delete accounting
// would diverge from the kernel.
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "first add")
require.Len(t, first, 1)
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err, "second add")
require.Len(t, second, 1)
require.Equal(t, first[0].ID(), second[0].ID(), "dedup to one rule")
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
require.NoError(t, m.DeleteFilterRule(first[0]), "delete once")
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
assert.NotContains(t, m.peerRulesMap, first[0].ID(), "dedup map entry cleared")
}
// TestDeletePeerFiltering_ByRuleID verifies a peer rule can be deleted
// by its bare RuleID, not only by the concrete *PeerRule, so a caller
// that tracks ids cannot accidentally fall through to the route path.
func TestDeletePeerFiltering_ByRuleID(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, rules, 1)
require.NoError(t, m.DeleteFilterRule(rules[0].ID()), "delete by bare id")
assert.Empty(t, m.incomingRules, "rule removed when deleted by id")
assert.NotContains(t, m.peerRulesMap, rules[0].ID())
}
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
// content hash, not a random UUID: identical inputs produce the same id
// across independent managers. A random id breaks caller-side dedup.
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
ip := net.ParseIP("10.0.0.5")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
m1 := newTestManager(t)
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
require.Len(t, r1, 1)
m2 := newTestManager(t)
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
require.Len(t, r2, 1)
assert.Equal(t, r1[0].ID(), r2[0].ID(), "same inputs must produce the same rule id")
}
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
// differing only by port are kept separate.
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
action := fw.ActionAccept
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
require.NoError(t, err)
require.Len(t, r80, 1)
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
require.NoError(t, err)
require.Len(t, r443, 1)
assert.NotEqual(t, r80[0].ID(), r443[0].ID(), "different ports must produce different rule ids")
assert.Len(t, m.incomingRules, 2, "distinct rules must both be stored")
}
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
// matching on source port and one matching on destination port for the
// same selector do not collide: the port lands in a different slot, so
// the content key must differ.
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
m := newTestManager(t)
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
require.Len(t, dPortRule, 1)
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
require.NoError(t, err)
require.Len(t, sPortRule, 1)
assert.NotEqual(t, dPortRule[0].ID(), sPortRule[0].ID(), "source-port and dest-port matches must produce different rule ids")
}
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
// list is rejected rather than treated as "match any". "Match any" must
// be an explicit /0, so a zeroed list can never silently widen a rule to
// every source.
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
m := newTestManager(t)
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
assert.Empty(t, m.incomingRules, "no rule should be stored for empty sources")
}

View File

@@ -0,0 +1,106 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func newV6TestManager(t *testing.T, localV6 string) *Manager {
t.Helper()
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
IPv6: netip.MustParseAddr(localV6),
IPv6Net: netip.MustParsePrefix("fd00::/64"),
}
},
}
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err, "create manager")
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
return m
}
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
t.Helper()
ip6 := &layers.IPv6{
Version: 6,
HopLimit: 64,
NextHeader: layers.IPProtocolUDP,
SrcIP: net.ParseIP(src),
DstIP: net.ParseIP(dst),
}
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
return buf.Bytes()
}
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
// rules: a matching v6 source is accepted, a non-matching one is
// denied by the default. This is the end-to-end proof that the index
// is not v4-only.
func TestPeerACL_IPv6HostRule(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
src := net.ParseIP("fd00::1")
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err, "add v6 accept rule")
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
"v6 packet from the allowed /128 source must be accepted")
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
"v6 packet from an unlisted source must be denied by default")
}
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
// right index bucket: a /128 in bySource keyed by its address, a
// coarser prefix in byCIDR, and ::/0 in matchAny.
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
m := newV6TestManager(t, "fd00::100")
port := &fw.Port{Values: []uint16{53}}
host := netip.MustParseAddr("fd00::1")
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.byCIDR, 1, "coarser v6 prefix must land in byCIDR")
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, m.incomingAcceptIndex.matchAny, 1, "::/0 source must land in matchAny")
}
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
// source prefix is normalized to v4 so a plain v4 packet matches it.
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
m := newTestManager(t)
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
rules, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
require.NoError(t, err)
require.Len(t, rules, 1)
v4 := netip.MustParseAddr("192.168.1.1")
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
}

View File

@@ -10,24 +10,50 @@ import (
// PeerRule to handle management of rules
type PeerRule struct {
id string
mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType
matchByIP bool
id firewall.RuleID
mgmtId []byte
// sources is the canonical list of source prefixes this rule
// matches against. A single 0.0.0.0/0 (or ::/0) entry means
// "match any source".
sources []netip.Prefix
// sourceAddrs is a fast-path membership set for host-prefix
// sources (/32 v4, /128 v6). Populated alongside sources;
// consulted before falling back to prefix scan.
sourceAddrs map[netip.Addr]struct{}
// matchAny is true when sources covers everything (0.0.0.0/0,
// ::/0). In that case neither sourceAddrs nor sources need to be
// consulted.
matchAny bool
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// matchesSource reports whether the given source address is covered
// by this rule's source list.
func (r *PeerRule) matchesSource(src netip.Addr) bool {
if r.matchAny {
return true
}
if _, ok := r.sourceAddrs[src]; ok {
return true
}
for _, p := range r.sources {
if p.Contains(src) {
return true
}
}
return false
}
// ID returns the rule id
func (r *PeerRule) ID() string {
func (r *PeerRule) ID() firewall.RuleID {
return r.id
}
type RouteRule struct {
id string
id firewall.RuleID
mgmtId []byte
sources []netip.Prefix
dstSet firewall.Set
@@ -39,6 +65,6 @@ type RouteRule struct {
}
// ID returns the rule id
func (r *RouteRule) ID() string {
func (r *RouteRule) ID() firewall.RuleID {
return r.id
}

View File

@@ -0,0 +1,50 @@
package uspfilter
import (
"net"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
// countRulesForAddr reports how many rules in the given slice match
// the supplied source address.
func countRulesForAddr(rules PeerRules, src netip.Addr) int {
n := 0
for _, r := range rules {
if r.matchesSource(src) {
n++
}
}
return n
}
// findRuleByID returns true if the rules slice contains a rule with
// the given id whose source set covers src.
func findRuleByID(rules PeerRules, src netip.Addr, id firewall.RuleID) bool {
for _, r := range rules {
if r.id == id && r.matchesSource(src) {
return true
}
}
return false
}
// pfx converts a single net.IP into the []netip.Prefix form
// AddFilterRule expects. A nil or unspecified address becomes a /0
// ("match any") prefix in the matching family; any other address
// becomes its /32 (or /128) host prefix.
func pfx(ip net.IP) []netip.Prefix {
if ip == nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if ip.IsUnspecified() {
if ip.To4() != nil {
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
a, _ := netip.AddrFromSlice(ip)
a = a.Unmap()
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
}

View File

@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {

View File

@@ -0,0 +1,190 @@
package acl
import (
"net/netip"
"sync"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
// invariant: the backends classify a rule as a peer rule purely by the
// absence of a destination (neither prefix nor set). A default route
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
// a route, not collapse into the peer path.
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
}
// A zero-value Network is the only peer-rule shape.
var empty fwmgr.Network
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
assert.False(t, empty.IsSet(), "zero Network must not be a set")
}
// TestDetermineDestinationAlwaysRoute verifies determineDestination
// never yields an empty Network for a valid route rule: every branch
// (static prefix, default route, dynamic with/without domains, with and
// without a local resolver) produces a destination that classifies as a
// route. If this regresses, a route rule would be dispatched down the
// peer path, which matches on source only.
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
cases := []struct {
name string
rule *mgmProto.RouteFirewallRule
resolver bool
sources []netip.Prefix
}{
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
require.NoError(t, err)
assert.True(t, dest.IsPrefix() || dest.IsSet(),
"destination must classify as a route, got empty Network")
})
}
}
// countingFirewall wraps a real firewall.Manager and counts filter-rule
// add/delete calls so a test can assert how many backing rules the acl
// manager actually creates and tears down.
type countingFirewall struct {
fwmgr.Manager
mu sync.Mutex
addCalls int
dels int
ruleIDs map[fwmgr.RuleID]struct{}
}
// distinctRules returns the number of distinct backing rules the
// backend produced. Because the backend dedups identical content,
// repeated AddFilterRule calls for the same rule resolve to one id.
func (f *countingFirewall) distinctRules() int {
f.mu.Lock()
defer f.mu.Unlock()
return len(f.ruleIDs)
}
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) ([]fwmgr.Rule, error) {
rules, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
if err == nil {
f.mu.Lock()
f.addCalls++
if f.ruleIDs == nil {
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
}
for _, r := range rules {
f.ruleIDs[r.ID()] = struct{}{}
}
f.mu.Unlock()
}
return rules, err
}
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
err := f.Manager.DeleteFilterRule(r)
if err == nil {
f.mu.Lock()
f.dels++
delete(f.ruleIDs, r.ID())
f.mu.Unlock()
}
return err
}
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
t.Helper()
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
fw := &countingFirewall{Manager: realFW}
cleanup := func() {
require.NoError(t, realFW.Close(nil))
ctrl.Finish()
}
return NewDefaultManager(fw), fw, cleanup
}
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
// the backends rely on: two policies that authorize an identical flow
// (same selector and sources) collapse to a single backing firewall
// rule, and that rule survives until BOTH policies are gone. This is
// why the backend can dedup on add without refcounting on delete: the
// acl manager's pair key matches the backend's content key, so add and
// delete stay balanced per content key across full-state reapplies.
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
acl, fw, cleanup := newCountingACL(t)
defer cleanup()
ruleA := &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
ruleB := &mgmProto.FirewallRule{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
// Both policies present: identical content collapses to one rule.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
// Drop policy A only: the shared rule is still authorized by B, so
// nothing is deleted.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
assert.Equal(t, 1, len(acl.peerRulesPairs))
// Drop policy B too: now the content key has no authorizer and the
// single backing rule is removed exactly once.
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
assert.Equal(t, 0, len(acl.peerRulesPairs))
}

View File

@@ -0,0 +1,238 @@
package acl
import (
"errors"
"net/netip"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
// with identical selectors but different PolicyIDs do NOT get merged
// into one group, so each policy's sources merge under its own
// attribution id. (Identical-content groups may still dedup to one
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-B"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, merr, _ := groupPeerRules(rules)
require.Nil(t, merr.ErrorOrNil())
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
}
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
// belonging to the same policy don't merge.
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "2001:db8::1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, merr, _ := groupPeerRules(rules)
require.Nil(t, merr.ErrorOrNil())
require.Len(t, groups, 2, "rules of different families must produce separate groups")
var sawV4, sawV6 bool
for _, g := range groups {
require.Len(t, g.sources, 1)
if g.sources[0].Addr().Is4() {
sawV4 = true
}
if g.sources[0].Addr().Is6() {
sawV6 = true
}
}
assert.True(t, sawV4 && sawV6)
}
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
// every distinguishing field (policy, family, direction, action,
// proto, port) collapse into a single multi-source group.
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
mk := func(peerIP string) *mgmProto.FirewallRule {
return &mgmProto.FirewallRule{
PolicyID: []byte("policy-A"),
PeerIP: peerIP,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
}
}
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
groups, merr, _ := groupPeerRules(rules)
require.Nil(t, merr.ErrorOrNil())
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
// new sourcePrefixes wire field is consumed and produces a
// multi-source group in one shot (no client-side merging needed).
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
srcs := [][]byte{
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
}
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
SourcePrefixes: srcs,
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, merr, _ := groupPeerRules(rules)
require.Nil(t, merr.ErrorOrNil())
require.Len(t, groups, 1)
require.Len(t, groups[0].sources, 3)
}
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
// and drop rules with the same selector don't merge.
func TestGroupPeerRulesActionSeparates(t *testing.T) {
rules := []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
}
groups, merr, _ := groupPeerRules(rules)
require.Nil(t, merr.ErrorOrNil())
require.Len(t, groups, 2)
}
// failingDeleteFirewall wraps a real firewall.Manager and forces the
// next N DeleteFilterRule calls to fail. Used to verify that the acl
// manager retains rules whose deletion was rejected by the backend,
// so they get retried on the next ApplyFiltering pass instead of
// becoming orphans.
type failingDeleteFirewall struct {
fwmgr.Manager
failCount int
}
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
if f.failCount > 0 {
f.failCount--
return errors.New("simulated delete failure")
}
return f.Manager.DeleteFilterRule(r)
}
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
// transient DeleteFilterRule error doesn't make the acl manager
// forget about a rule. The rule must remain in peerRulesPairs so the
// next ApplyFiltering pass attempts the delete again.
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() { require.NoError(t, realFW.Close(nil)) }()
fw := &failingDeleteFirewall{Manager: realFW}
acl := NewDefaultManager(fw)
// First pass: install a rule.
netmap1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PolicyID: []byte("policy-A"),
PeerIP: "10.0.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(netmap1, false)
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
// Second pass: remove the rule from the map. The backend will
// fail the delete; the acl manager must retain the rule.
fw.failCount = 1
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 1, len(acl.peerRulesPairs),
"rule must be retained when DeleteFilterRule fails so it gets retried")
// Third pass: same map, backend no longer fails. The rule
// should now succeed in being removed.
acl.ApplyFiltering(netmap2, false)
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
}

View File

@@ -5,18 +5,21 @@ import (
"encoding/hex"
"fmt"
"net/netip"
"slices"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
type RuleID string
// RuleID aliases manager.RuleID so existing nbid.RuleID references
// keep working while the canonical type lives in the firewall package.
type RuleID = manager.RuleID
func (r RuleID) ID() string {
return string(r)
}
func GenerateRouteRuleKey(
// GenerateRuleID returns a deterministic content hash identifying a
// filter rule. It covers both peer rules (empty destination) and route
// rules (prefix or set destination), so identical rules dedup to the
// same id across backends regardless of which path created them.
func GenerateRuleID(
sources []netip.Prefix,
destination manager.Network,
proto manager.Protocol,
@@ -24,6 +27,7 @@ func GenerateRouteRuleKey(
dPort *manager.Port,
action manager.Action,
) RuleID {
sources = slices.Clone(sources)
manager.SortPrefixes(sources)
h := sha256.New()

View File

@@ -1,8 +1,6 @@
package acl
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"net/netip"
@@ -31,7 +29,6 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
@@ -102,59 +99,271 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
)
}
newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
// Group incoming single-source rules from management by their
// (direction, action, proto, port) selector and merge sources.
// One call to the firewall backend per merged rule.
groups, merr, denyFailed := groupPeerRules(rules)
if denyFailed {
log.Errorf("a deny rule failed to decode its sources, skipping this pass to avoid fail-open: %v", nberrors.FormatErrorOrNil(merr))
return
}
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
// the missing deny. Currently we accumulate errors and continue.
var merr *multierror.Error
for _, r := range rules {
// if this rule is member of rule selection with more than DefaultIPsCountForSet
// it's IP address can be used in the ipset for firewall manager which supports it
selector := d.getRuleGroupingSelector(r)
ipsetName, ok := ipsetByRuleSelectors[selector]
if !ok {
d.ipsetCounter++
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
ipsetByRuleSelectors[selector] = ipsetName
}
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
continue
}
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
newRulePairs := make(map[id.RuleID][]firewall.Rule)
// Apply denies first. A deny that fails to install is a security
// failure (fail-open), so if any deny errors we roll back the
// denies we already installed in this pass and bail out without
// installing any accept. Pre-existing rules stay untouched until
// the next successful pass clears them.
denies, accepts := splitDenyAccept(groups)
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
log.Errorf("deny install failed, skipping accepts to avoid fail-open: %v", err)
return
}
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
merr = multierror.Append(merr, err)
}
if merr != nil {
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
}
// Tear down rules that disappeared from the networkmap. Any rule
// the backend refuses to delete stays in our tracking so it gets
// retried on the next ApplyFiltering. Otherwise a transient
// delete failure would leak the rule in the kernel until the
// process exits.
for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
if _, ok := newRulePairs[pairID]; ok {
continue
}
var remaining []firewall.Rule
for _, rule := range rules {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
remaining = append(remaining, rule)
}
delete(d.peerRulesPairs, pairID)
}
if len(remaining) > 0 {
newRulePairs[pairID] = remaining
}
}
d.peerRulesPairs = newRulePairs
}
// splitDenyAccept partitions groups by action so denies can be
// applied before accepts. Order within each bucket is preserved.
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
for _, g := range groups {
if g.action == mgmProto.RuleAction_DROP {
denies = append(denies, g)
} else {
accepts = append(accepts, g)
}
}
return denies, accepts
}
// installPeerGroups applies each group and records the resulting rule
// pairs in newRulePairs. With atomic set (deny rules), a single failure
// rolls back every rule installed in this call and returns, leaving the
// kernel exactly as before: denies are fail-closed and must be applied
// all-or-nothing. With atomic unset (accept rules), failures are
// accumulated and the remaining groups still install, so one malformed
// allow cannot drop every other legitimate allow in the pass.
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
var freshlyInstalled []id.RuleID
var merr *multierror.Error
for _, g := range groups {
pairID, rulePair, err := d.applyPeerGroup(g)
if err != nil {
if atomic {
d.rollbackInstalled(freshlyInstalled)
return fmt.Errorf("apply firewall rule: %w", err)
}
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
continue
}
if len(rulePair) == 0 {
continue
}
if _, existed := d.peerRulesPairs[pairID]; !existed {
freshlyInstalled = append(freshlyInstalled, pairID)
}
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
for _, pairID := range pairIDs {
for _, rule := range d.peerRulesPairs[pairID] {
if err := d.firewall.DeleteFilterRule(rule); err != nil {
log.Errorf("rollback peer rule %s: %v", pairID, err)
}
}
delete(d.peerRulesPairs, pairID)
}
}
// peerRuleGroup collapses a set of single-source FirewallRules sharing
// the same selector into one multi-source rule to push to the backend.
type peerRuleGroup struct {
direction mgmProto.RuleDirection
action mgmProto.RuleAction
protocol mgmProto.RuleProtocol
port *mgmProto.PortInfo
// legacyPort is used only when PortInfo is empty (old management).
legacyPort string
policyID []byte
sources []netip.Prefix
}
// groupPeerRules merges single-source rules sharing a selector into
// multi-source groups. The bool return reports whether any deny rule
// failed to decode its sources: a deny we cannot realize is a
// fail-open risk, so the caller skips the whole pass and retries rather
// than installing accepts on top of a missing deny.
func groupPeerRules(rules []*mgmProto.FirewallRule) ([]*peerRuleGroup, *multierror.Error, bool) {
var merr *multierror.Error
denyFailed := false
byKey := make(map[string]*peerRuleGroup)
order := make([]string, 0)
for _, r := range rules {
srcs, err := extractRuleSources(r)
if err != nil {
merr = multierror.Append(merr, err)
if r.Action == mgmProto.RuleAction_DROP {
denyFailed = true
}
continue
}
// extractRuleSources returns at least one source on success;
// pick the family from the first to key this group. Sources
// from a single FirewallRule are always same-family (mgmt
// emits one rule per family), so this is unambiguous.
family := familyTag(srcs[0])
key := ruleGroupKey(r, family)
g, ok := byKey[key]
if !ok {
g = &peerRuleGroup{
direction: r.Direction,
action: r.Action,
protocol: r.Protocol,
port: r.PortInfo,
legacyPort: r.Port,
policyID: r.PolicyID,
}
byKey[key] = g
order = append(order, key)
}
g.sources = append(g.sources, srcs...)
}
out := make([]*peerRuleGroup, 0, len(order))
for _, k := range order {
out = append(out, byKey[k])
}
return out, merr, denyFailed
}
func familyTag(p netip.Prefix) string {
if p.Addr().Is6() && !p.Addr().Is4In6() {
return "v6"
}
return "v4"
}
// ruleGroupKey returns a string that uniquely identifies a peer-rule
// selector. Rules sharing a key can be collapsed into one multi-source
// rule pushed to the firewall backend.
//
// All distinguishing fields must be in the key:
// - family (v4 vs v6): mgmt emits one FirewallRule per family per
// peer, and merging would produce mixed-family groups that broke
// ICMP-variant selection in uspfilter.
// - policyID: two policies may authorize different source peers for
// the same proto/port/direction. Keeping them in separate groups
// keeps each policy's sources in its own backend rule instead of
// merging unrelated peers into one rule attributed to a single
// policy. (When two policies authorize the identical sources and
// selector, the backend dedups them to one rule regardless,
// attributed to whichever was applied first.)
// - direction, action, protocol, port: behavioral fields; mismatched
// rules must produce mismatched kernel rules.
func ruleGroupKey(r *mgmProto.FirewallRule, family string) string {
return fmt.Sprintf("%s:%x:%d:%d:%d:%s:%v",
family, r.PolicyID, r.Direction, r.Action, r.Protocol, r.Port, r.PortInfo)
}
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
protocol, err := convertToFirewallProtocol(g.protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
action, err := convertFirewallAction(g.action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
}
port, err := resolveGroupPort(g)
if err != nil {
return "", nil, err
}
var fwRules []firewall.Rule
switch g.direction {
case mgmProto.RuleDirection_IN:
fwRules, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
if shouldSkipInvertedRule(protocol, port) {
return "", nil, nil
}
fwRules, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, fmt.Errorf("add firewall rule: %w", err)
}
if len(fwRules) == 0 {
return "", nil, nil
}
// Derive the pair id from the backend rule, like the route path:
// the backend dedups identical content, so two policies authorizing
// the same flow resolve to the same id and a single backing rule.
return fwRules[0].ID(), fwRules, nil
}
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
if !portInfoEmpty(g.port) {
return convertPortInfo(g.port), nil
}
if g.legacyPort != "" {
value, err := strconv.Atoi(g.legacyPort)
if err != nil {
return nil, fmt.Errorf("invalid port: %w", err)
}
return &firewall.Port{Values: []uint16{uint16(value)}}, nil
}
// nolint:nilnil // a nil port legitimately means "no port restriction"
return nil, nil
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule, dynamicResolver)
ruleID, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
@@ -163,16 +372,18 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
}
continue
}
newRouteRules[id] = struct{}{}
newRouteRules[ruleID] = struct{}{}
}
// Clean up old firewall rules
for id := range d.routeRules {
if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
}
// implicitly deleted from the map
// Tear down old route rules; retain ones the backend refused so a
// transient failure doesn't leave kernel-side orphans.
for ruleID := range d.routeRules {
if _, exists := newRouteRules[ruleID]; exists {
continue
}
if err := d.firewall.DeleteFilterRule(ruleID); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
newRouteRules[ruleID] = struct{}{}
}
}
@@ -191,7 +402,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
if err != nil {
return "", fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, source)
sources = append(sources, firewall.UnmapPrefix(source))
}
destination, err := determineDestination(rule, dynamicResolver, sources)
@@ -211,71 +422,15 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
addedRules, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
if err != nil {
return "", fmt.Errorf("add route rule: %w", err)
}
return id.RuleID(addedRule.ID()), nil
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
) (id.RuleID, []firewall.Rule, error) {
ip, err := extractRuleIP(r)
if err != nil {
return "", nil, err
if len(addedRules) == 0 {
return "", fmt.Errorf("add route rule: backend returned no rules")
}
protocol, err := convertToFirewallProtocol(r.Protocol)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
action, err := convertFirewallAction(r.Action)
if err != nil {
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
}
var port *firewall.Port
if !portInfoEmpty(r.PortInfo) {
port = convertPortInfo(r.PortInfo)
} else if r.Port != "" {
// old version of management, single port
value, err := strconv.Atoi(r.Port)
if err != nil {
return "", nil, fmt.Errorf("invalid port: %w", err)
}
port = &firewall.Port{
Values: []uint16{uint16(value)},
}
}
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT:
if d.firewall.IsStateful() {
return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
}
if err != nil {
return "", nil, err
}
return ruleID, rules, nil
return addedRules[0].ID(), nil
}
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
@@ -294,82 +449,28 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
}
}
func (d *DefaultManager) addInRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
func (d *DefaultManager) addOutRules(
id []byte,
ip netip.Addr,
protocol firewall.Protocol,
port *firewall.Port,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) {
return nil, nil
}
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return rule, nil
}
// getPeerRuleID returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID(
ip netip.Addr,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil {
idStr += port.String()
}
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// getRuleGroupingSelector takes all rule properties except IP address to build selector
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
}
// extractRuleIP extracts the peer IP from a firewall rule.
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
// Otherwise fall back to the deprecated PeerIP string field (old management).
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
// extractRuleSources returns all source prefixes the rule applies to.
// New management populates sourcePrefixes; older management sets PeerIP.
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
if len(r.SourcePrefixes) > 0 {
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
if err != nil {
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
for _, raw := range r.SourcePrefixes {
addr, err := netiputil.DecodeAddr(raw)
if err != nil {
return nil, fmt.Errorf("decode source prefix: %w", err)
}
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
}
return addr.Unmap(), nil
return out, nil
}
//nolint:staticcheck // PeerIP used for backward compatibility with old management
addr, err := netip.ParseAddr(r.PeerIP)
if err != nil {
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
}
return addr.Unmap(), nil
addr = addr.Unmap()
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
}
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
@@ -76,9 +77,9 @@ func TestDefaultManager(t *testing.T) {
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
existedPairs := map[fwmanager.RuleID]struct{}{}
for id := range acl.peerRulesPairs {
existedPairs[id.ID()] = struct{}{}
existedPairs[id] = struct{}{}
}
// remove first rule
@@ -105,7 +106,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed
previousCount := 0
for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id.ID()]; ok {
if _, ok := existedPairs[id]; ok {
previousCount++
}
}

View File

@@ -360,7 +360,13 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
return true
}
addr := fmt.Sprintf(":%s", port)
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
// alias by default, ensure explicit ip for localhost.
host := parsedURL.Hostname()
if host == "" {
host = "127.0.0.1"
}
addr := net.JoinHostPort(host, port)
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
if err != nil {
return false

View File

@@ -3,7 +3,6 @@ package dnsfwd
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"strconv"
@@ -160,12 +159,13 @@ func (m *Manager) allowDNSFirewall() error {
return nil
}
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
dnsRules, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
if err != nil {
return fmt.Errorf("add udp firewall rule: %w", err)
}
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
tcpRules, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
if err != nil {
return fmt.Errorf("add tcp firewall rule: %w", err)
}
@@ -209,12 +209,12 @@ func (m *Manager) unregisterNetstackServices() {
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range m.fwRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
for _, rule := range m.tcpRules {
if err := m.firewall.DeletePeerRule(rule); err != nil {
if err := m.firewall.DeleteFilterRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}

View File

@@ -640,14 +640,14 @@ func (e *Engine) initFirewall() error {
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
if _, err := e.firewall.AddPeerFiltering(
if _, err := e.firewall.AddFilterRule(
nil,
net.IP{0, 0, 0, 0},
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
firewallManager.Network{},
firewallManager.ProtocolUDP,
nil,
&port,
firewallManager.ActionAccept,
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
@@ -697,7 +697,7 @@ func (e *Engine) blockLanAccess() {
if network.Addr().Is6() {
source = v6
}
if _, err := e.firewall.AddRouteFiltering(
if _, err := e.firewall.AddFilterRule(
nil,
[]netip.Prefix{source},
firewallManager.Network{Prefix: network},

View File

@@ -24,14 +24,14 @@ type RulePair struct {
type Manager struct {
dnatFirewall DNATFirewall
rules map[string]RulePair // keys is the ID of the ForwardRule
rules map[firewall.RuleID]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[string]RulePair),
rules: make(map[firewall.RuleID]RulePair),
}
}
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
var mErr *multierror.Error
toDelete := make(map[string]RulePair, len(h.rules))
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
@@ -90,7 +90,7 @@ func (h *Manager) Close() error {
}
}
h.rules = make(map[string]RulePair)
h.rules = make(map[firewall.RuleID]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -14,11 +14,11 @@ var (
)
type MocFwRule struct {
id string
id firewall.RuleID
}
func (m *MocFwRule) ID() string {
return string(m.id)
func (m *MocFwRule) ID() firewall.RuleID {
return m.id
}
type MockDNATFirewall struct {

View File

@@ -179,8 +179,10 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv4zero
if runtime.GOOS == "linux" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
dst = net.IPv4(0, 0, 0, 1)
}
_, gateway, localIP, err = router.Route(dst)
@@ -203,7 +205,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
}
dst := net.IPv6zero
if runtime.GOOS == "linux" {
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
// ::2
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
}

View File

@@ -67,6 +67,10 @@ func init() {
rootCmd.AddCommand(newTokenCommands())
}
func RootCmd() *cobra.Command {
return rootCmd
}
func Execute() error {
return rootCmd.Execute()
}
@@ -168,7 +172,7 @@ func initializeConfig() error {
// serverInstances holds all server instances created during startup.
type serverInstances struct {
relaySrv *relayServer.Server
mgmtSrv *mgmtServer.BaseServer
mgmtSrv mgmtServer.Server
signalSrv *signalServer.Server
healthcheck *healthcheck.Server
stunServer *stun.Server
@@ -324,19 +328,24 @@ func setupServerHooks(servers *serverInstances, cfg *CombinedConfig) {
return
}
servers.mgmtSrv.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if s, ok := servers.mgmtSrv.GetContainer(mgmtServer.ContainerKeyBaseServer); ok {
if baseServer, ok := s.(*mgmtServer.BaseServer); ok {
baseServer.AfterInit(func(s *mgmtServer.BaseServer) {
grpcSrv := s.GRPCServer()
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
if servers.signalSrv != nil {
proto.RegisterSignalExchangeServer(grpcSrv, servers.signalSrv)
log.Infof("Signal server registered on port %s", cfg.Server.ListenAddress)
}
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
s.SetHandlerFunc(createCombinedHandler(grpcSrv, s.APIHandler(), s.IDPHandler(), servers.relaySrv, servers.metricsServer.Meter, cfg))
if servers.relaySrv != nil {
log.Infof("Relay WebSocket handler added (path: /relay)")
}
})
}
})
}
}
func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, metricsServer *sharedMetrics.Metrics) {
@@ -346,38 +355,32 @@ func startServers(wg *sync.WaitGroup, srv *relayServer.Server, httpHealthcheck *
log.Infof("Relay WebSocket multiplexed on management port (no separate relay listener)")
}
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start metrics server: %v", err)
}
}()
})
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
if err := httpHealthcheck.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("failed to start healthcheck server: %v", err)
}
}()
})
if stunServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
if err := stunServer.Listen(); err != nil {
if errors.Is(err, stun.ErrServerClosed) {
return
}
log.Errorf("STUN server error: %v", err)
}
}()
})
}
}
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv *mgmtServer.BaseServer, metricsServer *sharedMetrics.Metrics) error {
func shutdownServers(ctx context.Context, srv *relayServer.Server, httpHealthcheck *healthcheck.Server, stunServer *stun.Server, mgmtSrv mgmtServer.Server, metricsServer *sharedMetrics.Metrics) error {
var errs error
if err := httpHealthcheck.Shutdown(ctx); err != nil {
@@ -491,7 +494,7 @@ func handleTLSConfig(cfg *CombinedConfig) (*tls.Config, bool, error) {
return nil, false, nil
}
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*mgmtServer.BaseServer, error) {
func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (mgmtServer.Server, error) {
mgmt := cfg.Management
// Extract port from listen address
@@ -502,7 +505,7 @@ func createManagementServer(cfg *CombinedConfig, mgmtConfig *nbconfig.Config) (*
}
mgmtPort, _ := strconv.Atoi(portStr)
mgmtSrv := mgmtServer.NewServer(
mgmtSrv := newServer(
&mgmtServer.Config{
NbConfig: mgmtConfig,
DNSDomain: "",

13
combined/cmd/server.go Normal file
View File

@@ -0,0 +1,13 @@
package cmd
import (
mgmtServer "github.com/netbirdio/netbird/management/internals/server"
)
var newServer = func(cfg *mgmtServer.Config) mgmtServer.Server {
return mgmtServer.NewServer(cfg)
}
func SetNewServer(fn func(*mgmtServer.Config) mgmtServer.Server) {
newServer = fn
}

View File

@@ -75,7 +75,7 @@ func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
}
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
allowed, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -88,7 +88,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
}
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
allowed, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}

View File

@@ -63,7 +63,7 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
// GetAllAccessLogs retrieves access logs for an account with pagination and filtering
func (m *managerImpl) GetAllAccessLogs(ctx context.Context, accountID, userID string, filter *accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, 0, status.NewPermissionValidationError(err)
}

View File

@@ -57,7 +57,7 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -122,7 +122,7 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
}
func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -163,7 +163,7 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
}
func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -187,7 +187,7 @@ func (m Manager) DeleteDomain(ctx context.Context, accountID, userID, domainID s
}
func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID string) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
ok, _, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
log.WithFields(log.Fields{
"accountID": accountID,

View File

@@ -37,7 +37,7 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Create)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
@@ -76,13 +76,13 @@ func (h *handler) createToken(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.store.SaveProxyAccessToken(r.Context(), &generated.ProxyAccessToken); err != nil {
if err := h.store.SaveProxyAccessToken(ctx, &generated.ProxyAccessToken); err != nil {
util.WriteErrorResponse("failed to save token", http.StatusInternalServerError, w)
return
}
resp := toProxyTokenCreatedResponse(generated)
util.WriteJSONObject(r.Context(), w, resp)
util.WriteJSONObject(ctx, w, resp)
}
func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
@@ -92,7 +92,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Read)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
@@ -102,7 +102,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
return
}
tokens, err := h.store.GetProxyAccessTokensByAccountID(r.Context(), store.LockingStrengthNone, userAuth.AccountId)
tokens, err := h.store.GetProxyAccessTokensByAccountID(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil {
util.WriteErrorResponse("failed to list tokens", http.StatusInternalServerError, w)
return
@@ -113,7 +113,7 @@ func (h *handler) listTokens(w http.ResponseWriter, r *http.Request) {
resp = append(resp, toProxyTokenResponse(token))
}
util.WriteJSONObject(r.Context(), w, resp)
util.WriteJSONObject(ctx, w, resp)
}
func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
@@ -123,7 +123,7 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
return
}
ok, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
ok, ctx, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Services, operations.Delete)
if err != nil {
util.WriteErrorResponse("failed to validate permissions", http.StatusInternalServerError, w)
return
@@ -139,7 +139,7 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
return
}
token, err := h.store.GetProxyAccessTokenByID(r.Context(), store.LockingStrengthNone, tokenID)
token, err := h.store.GetProxyAccessTokenByID(ctx, store.LockingStrengthNone, tokenID)
if err != nil {
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
util.WriteErrorResponse("token not found", http.StatusNotFound, w)
@@ -154,12 +154,12 @@ func (h *handler) revokeToken(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.store.RevokeProxyAccessToken(r.Context(), tokenID); err != nil {
if err := h.store.RevokeProxyAccessToken(ctx, tokenID); err != nil {
util.WriteErrorResponse("failed to revoke token", http.StatusInternalServerError, w)
return
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
util.WriteJSONObject(ctx, w, util.EmptyObject{})
}
func toProxyTokenResponse(token *types.ProxyAccessToken) api.ProxyToken {

View File

@@ -47,7 +47,7 @@ func TestCreateToken_AccountScoped(t *testing.T) {
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
h := &handler{
store: mockStore,
@@ -90,7 +90,7 @@ func TestCreateToken_WithExpiration(t *testing.T) {
)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
h := &handler{
store: mockStore,
@@ -115,7 +115,7 @@ func TestCreateToken_EmptyName(t *testing.T) {
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(true, context.Background(), nil)
h := &handler{
permissionsManager: permsMgr,
@@ -135,7 +135,7 @@ func TestCreateToken_PermissionDenied(t *testing.T) {
defer ctrl.Finish()
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Create).Return(false, context.Background(), nil)
h := &handler{
permissionsManager: permsMgr,
@@ -164,7 +164,7 @@ func TestListTokens(t *testing.T) {
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Read).Return(true, context.Background(), nil)
h := &handler{
store: mockStore,
@@ -202,7 +202,7 @@ func TestRevokeToken_Success(t *testing.T) {
mockStore.EXPECT().RevokeProxyAccessToken(gomock.Any(), "tok-1").Return(nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), accountID, "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
h := &handler{
store: mockStore,
@@ -231,7 +231,7 @@ func TestRevokeToken_WrongAccount(t *testing.T) {
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
h := &handler{
store: mockStore,
@@ -258,7 +258,7 @@ func TestRevokeToken_ManagementWideToken(t *testing.T) {
}, nil)
permsMgr := permissions.NewMockManager(ctrl)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, nil)
permsMgr.EXPECT().ValidateUserPermissions(gomock.Any(), "acc-123", "user-1", modules.Services, operations.Delete).Return(true, context.Background(), nil)
h := &handler{
store: mockStore,

View File

@@ -120,7 +120,7 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
// capability flags reported by its active proxies so the dashboard can
// render feature support without a second round-trip.
func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -146,7 +146,7 @@ func (m *Manager) GetClusters(ctx context.Context, accountID, userID string) ([]
// DeleteAccountCluster removes all proxy registrations for the given cluster address
// owned by the account.
func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, clusterAddress string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -158,7 +158,7 @@ func (m *Manager) DeleteAccountCluster(ctx context.Context, accountID, userID, c
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -222,7 +222,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
}
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -243,7 +243,7 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
}
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -528,7 +528,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
}
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -836,7 +836,7 @@ func validateResourceTargetType(target *service.Target, resource *resourcetypes.
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -876,7 +876,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
}
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

View File

@@ -1172,7 +1172,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
mockPerms.EXPECT().
ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete).
Return(true, nil)
Return(true, ctx, nil)
mockAcct.EXPECT().
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
mockAcct.EXPECT().

View File

@@ -32,7 +32,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -44,7 +44,7 @@ func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string)
}
func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -56,7 +56,7 @@ func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID str
}
func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -103,7 +103,7 @@ func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -151,7 +151,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
}
func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

View File

@@ -79,7 +79,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.NoError(t, err)
@@ -95,7 +95,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
@@ -112,7 +112,7 @@ func TestManagerImpl_GetAllZones(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
Return(false, ctx, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllZones(ctx, testAccountID, testUserID)
require.Error(t, err)
@@ -134,7 +134,7 @@ func TestManagerImpl_GetZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
@@ -150,7 +150,7 @@ func TestManagerImpl_GetZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
@@ -179,7 +179,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -212,7 +212,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -235,7 +235,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -261,7 +261,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -293,7 +293,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -319,7 +319,7 @@ func TestManagerImpl_CreateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone)
require.Error(t, err)
@@ -354,7 +354,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -394,7 +394,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
@@ -418,7 +418,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
@@ -441,7 +441,7 @@ func TestManagerImpl_UpdateZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone)
require.Error(t, err)
@@ -471,7 +471,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
Return(true, ctx, nil)
storeEventCallCount := 0
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -503,7 +503,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
Return(true, ctx, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -529,7 +529,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
Return(false, ctx, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID)
require.Error(t, err)
@@ -545,7 +545,7 @@ func TestManagerImpl_DeleteZone(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
Return(true, ctx, nil)
err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone")
require.Error(t, err)

View File

@@ -32,7 +32,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
}
func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -44,7 +44,7 @@ func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zone
}
func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -56,7 +56,7 @@ func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID,
}
func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -102,7 +102,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -161,7 +161,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
}
func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

View File

@@ -80,7 +80,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
@@ -113,7 +113,7 @@ func TestManagerImpl_GetAllRecords(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, status.Errorf(status.Internal, "permission check failed"))
Return(false, ctx, status.Errorf(status.Internal, "permission check failed"))
result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID)
require.Error(t, err)
@@ -135,7 +135,7 @@ func TestManagerImpl_GetRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID)
require.NoError(t, err)
@@ -153,7 +153,7 @@ func TestManagerImpl_GetRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
@@ -181,7 +181,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -215,7 +215,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -244,7 +244,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
assert.Equal(t, testUserID, initiatorID)
@@ -273,7 +273,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -297,7 +297,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -323,7 +323,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -349,7 +349,7 @@ func TestManagerImpl_CreateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord)
require.Error(t, err)
@@ -380,7 +380,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -418,7 +418,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
// Event should be stored
@@ -445,7 +445,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(false, nil)
Return(false, ctx, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
@@ -470,7 +470,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
@@ -500,7 +500,7 @@ func TestManagerImpl_UpdateRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update).
Return(true, nil)
Return(true, ctx, nil)
result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord)
require.Error(t, err)
@@ -523,7 +523,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
Return(true, ctx, nil)
storeEventCalled := false
mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) {
@@ -549,7 +549,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(false, nil)
Return(false, ctx, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID)
require.Error(t, err)
@@ -565,7 +565,7 @@ func TestManagerImpl_DeleteRecord(t *testing.T) {
mockPermissionsManager.EXPECT().
ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete).
Return(true, nil)
Return(true, ctx, nil)
err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record")
require.Error(t, err)

View File

@@ -34,6 +34,8 @@ const (
ManagementLegacyPort = 33073
// DefaultSelfHostedDomain is the default domain used for self-hosted fresh installs.
DefaultSelfHostedDomain = "netbird.selfhosted"
ContainerKeyBaseServer = "baseServer"
)
type Server interface {
@@ -91,7 +93,7 @@ type Config struct {
// NewServer initializes and configures a new Server instance
func NewServer(cfg *Config) *BaseServer {
return &BaseServer{
s := &BaseServer{
Config: cfg.NbConfig,
container: make(map[string]any),
dnsDomain: cfg.DNSDomain,
@@ -104,6 +106,9 @@ func NewServer(cfg *Config) *BaseServer {
mgmtMetricsPort: cfg.MgmtMetricsPort,
autoResolveDomains: cfg.AutoResolveDomains,
}
s.container[ContainerKeyBaseServer] = s
return s
}
func (s *BaseServer) AfterInit(fn func(s *BaseServer)) {

View File

@@ -6,9 +6,11 @@ import (
"net/netip"
"net/url"
"strings"
"time"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
@@ -185,9 +187,38 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
}
// settings == nil → field stays nil → "no info in this snapshot", client
// preserves the deadline it already had. settings non-nil → emit either a
// valid deadline or the explicit-zero "disabled" sentinel via
// encodeSessionExpiresAt.
if settings != nil {
response.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
}
return response
}
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
// representation used on LoginResponse, SyncResponse and
// ExtendAuthSessionResponse. See the proto comments on those messages.
//
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
// its anchor.
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
// UTC deadline.
//
// Returning nil ("no info, preserve client's anchor") is the caller's job —
// only meaningful on Sync builds where settings were not resolved.
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
if deadline.IsZero() {
return &timestamppb.Timestamp{}
}
return timestamppb.New(deadline)
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte

View File

@@ -5,6 +5,7 @@ import (
"net/netip"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
@@ -200,3 +201,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
})
}
}
// TestEncodeSessionExpiresAt pins the wire encoding the client's
// applySessionDeadline depends on:
//
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
// "expiry disabled or peer is not SSO-tracked" sentinel.
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
//
// The third state (nil pointer = "no info in this snapshot") is the caller's
// responsibility on the Sync path when settings could not be resolved; the
// helper itself never returns nil.
func TestEncodeSessionExpiresAt(t *testing.T) {
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
got := encodeSessionExpiresAt(time.Time{})
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
assert.Equal(t, int64(0), got.GetSeconds())
assert.Equal(t, int32(0), got.GetNanos())
})
t.Run("non-zero deadline round-trips", func(t *testing.T) {
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
got := encodeSessionExpiresAt(deadline)
assert.NotNil(t, got)
assert.True(t, got.AsTime().Equal(deadline))
})
}

View File

@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}, nil
}
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
// stays up; no network map sync is performed. The new deadline is returned
// in ExtendAuthSessionResponse.SessionExpiresAt.
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
extendReq := &proto.ExtendAuthSessionRequest{}
peerKey, err := s.parseRequest(ctx, req, extendReq)
if err != nil {
return nil, err
}
//nolint
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
}
jwt := extendReq.GetJwtToken()
if jwt == "" {
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
}
var userID string
const attempts = 3
for i := 0; i < attempts; i++ {
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
if err == nil {
break
}
if i == attempts-1 {
break
}
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
select {
case <-time.After(200 * time.Millisecond):
case <-ctx.Done():
return nil, ctx.Err()
}
}
if err != nil {
return nil, err
}
if userID == "" {
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
}
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
if err != nil {
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
return nil, mapError(ctx, err)
}
// Success path normally returns a non-zero deadline. A defensive zero
// would still encode as the explicit "disabled" sentinel rather than nil,
// so the client clears any stale anchor instead of preserving it.
resp := &proto.ExtendAuthSessionResponse{
SessionExpiresAt: encodeSessionExpiresAt(deadline),
}
wgKey, err := s.secretsManager.GetWGKey()
if err != nil {
return nil, status.Errorf(codes.Internal, "failed processing request")
}
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed encrypting response")
}
return &proto.EncryptedMessage{
WgPubKey: wgKey.PublicKey().String(),
Body: encrypted,
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
Checks: toProtocolChecks(ctx, postureChecks),
}
// settings is always non-nil here, so we never emit nil — encoder returns
// either a valid deadline or the explicit-zero "disabled" sentinel.
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
)
return loginResp, nil
}

View File

@@ -282,7 +282,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// User that performs the update has to belong to the account.
// Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
// Session deadline is derived from LastLogin + PeerLoginExpiration
// on every Login/Sync response. Without a fan-out push, connected
// peers keep the deadline they received at login time and only see
// the new value after the next unrelated NetworkMap change. Add
// these two fields to the trigger list so admin-side expiry tweaks
// (e.g. shortening from 24h to 1h) reach every connected peer
// within seconds, which is what the proactive-warning feature
// relies on (see client/internal/auth/sessionwatch).
updateAccountPeers = true
}
@@ -845,7 +855,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Delete)
if err != nil {
return fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -1412,7 +1422,7 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin
// GetAccountByID returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -1425,7 +1435,7 @@ func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID s
// GetAccountMeta returns the account metadata associated with this account ID.
func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -1438,7 +1448,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
// GetAccountOnboarding retrieves the onboarding information for a specific account.
func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -1463,7 +1473,7 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
}
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
}
@@ -1530,7 +1540,8 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return accountID, user.Id, nil
}
if err := am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false); err != nil {
ctx, err = am.permissionsManager.ValidateAccountAccess(ctx, accountID, user, false)
if err != nil {
return "", "", err
}
@@ -1976,7 +1987,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
}
func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -2544,7 +2555,7 @@ func (am *DefaultAccountManager) validateIPForUpdate(account *types.Account, pee
}
func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}
@@ -2634,7 +2645,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
// UpdatePeerIPv6 updates the IPv6 overlay address of a peer, validating it's
// within the account's v6 network range and not already taken.
func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Update)
if err != nil {
return fmt.Errorf("validate user permissions: %w", err)
}

View File

@@ -109,6 +109,7 @@ type Manager interface {
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)

View File

@@ -1304,6 +1304,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
}
// ExtendPeerSession mocks base method.
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
ret0, _ := ret[0].(time.Time)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()

View File

@@ -240,6 +240,10 @@ const (
AccountLocalMfaEnabled Activity = 123
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
AccountLocalMfaDisabled Activity = 124
// UserExtendedPeerSession indicates that a user refreshed their peer's
// SSO session deadline via ExtendAuthSession without re-establishing the
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
UserExtendedPeerSession Activity = 125
AccountDeleted Activity = 99999
)
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
DomainAdded: {"Domain added", "domain.add"},
DomainDeleted: {"Domain deleted", "domain.delete"},
DomainValidated: {"Domain validated", "domain.validate"},

View File

@@ -1,10 +1,27 @@
package context
import "github.com/netbirdio/netbird/shared/context"
import (
"context"
nbcontext "github.com/netbirdio/netbird/shared/context"
)
const (
RequestIDKey = context.RequestIDKey
AccountIDKey = context.AccountIDKey
UserIDKey = context.UserIDKey
PeerIDKey = context.PeerIDKey
RequestIDKey = nbcontext.RequestIDKey
AccountIDKey = nbcontext.AccountIDKey
RoleKey = nbcontext.RoleKey
UserIDKey = nbcontext.UserIDKey
PeerIDKey = nbcontext.PeerIDKey
)
// RoleFromContext returns the role stored in ctx, or empty string and false if absent.
func RoleFromContext(ctx context.Context) (string, bool) {
role, ok := ctx.Value(RoleKey).(string)
return role, ok
}
// WithRole returns a new context carrying the given role.
func WithRole(ctx context.Context, role string) context.Context {
//nolint
return context.WithValue(ctx, RoleKey, role)
}

View File

@@ -22,7 +22,7 @@ const (
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
@@ -39,7 +39,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}

View File

@@ -23,7 +23,7 @@ func isEnabled() bool {
// GetEvents returns a list of activity events of an account
func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Events, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}

View File

@@ -32,7 +32,7 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
allowed, _, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Read)
if err != nil {
return err
}
@@ -70,7 +70,7 @@ func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName,
// CreateGroup object of the peers
func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -125,7 +125,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
// UpdateGroup object of the peers
func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -200,7 +200,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Create)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -268,7 +268,7 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
// This method will not create group peer membership relations. Use AddPeerToGroup or RemovePeerFromGroup methods for that.
func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
@@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
allowed, ctx, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Groups, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}

Some files were not shown because too many files have changed in this diff Show More