mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-01 21:49:56 +00:00
Compare commits
20 Commits
task/align
...
peer-acl-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9925f64e76 | ||
|
|
9189625487 | ||
|
|
e9dbf9db6f | ||
|
|
5a9e9e7bc9 | ||
|
|
43e041cf9f | ||
|
|
77e5693200 | ||
|
|
174dc24867 | ||
|
|
7ea5e37dd4 | ||
|
|
9d7ef9b255 | ||
|
|
944a258459 | ||
|
|
1f9a829f2c | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 | ||
|
|
d542c60e21 | ||
|
|
4983b5cf17 | ||
|
|
b3b0feb3b8 | ||
|
|
7aebdd69dd | ||
|
|
0358be2313 |
45
.github/dependabot.yml
vendored
Normal file
45
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
ignore:
|
||||
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
|
||||
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
|
||||
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
|
||||
- dependency-name: "git-town/action"
|
||||
update-types:
|
||||
- "version-update:semver-minor"
|
||||
- "version-update:semver-major"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
aws-sdk:
|
||||
patterns:
|
||||
- "github.com/aws/aws-sdk-go-v2/*"
|
||||
pion:
|
||||
patterns:
|
||||
- "github.com/pion/*"
|
||||
gorm:
|
||||
patterns:
|
||||
- "gorm.io/*"
|
||||
otel:
|
||||
patterns:
|
||||
- "go.opentelemetry.io/*"
|
||||
testcontainers:
|
||||
patterns:
|
||||
- "github.com/testcontainers/testcontainers-go/*"
|
||||
wireguard:
|
||||
patterns:
|
||||
- "golang.zx2c4.com/wireguard*"
|
||||
109
.github/workflows/check-license-dependencies.yml
vendored
109
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
@@ -19,7 +19,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
@@ -56,55 +59,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
|
||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,11 +8,10 @@ jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
discourse-tags: releases
|
||||
|
||||
8
.github/workflows/git-town.yml
vendored
8
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
- "**"
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
@@ -15,7 +15,9 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
|
||||
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,16 +16,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -44,4 +46,3 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
21
.github/workflows/golang-test-freebsd.yml
vendored
21
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,20 +15,31 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
|
||||
142
.github/workflows/golang-test-linux.yml
vendored
142
.github/workflows/golang-test-linux.yml
vendored
@@ -18,9 +18,11 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
@@ -28,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -36,10 +38,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -113,14 +115,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -128,10 +132,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -154,18 +158,20 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -177,7 +183,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -214,7 +220,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -231,10 +237,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,10 +254,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -277,14 +285,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -298,7 +308,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -324,14 +334,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -343,10 +355,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -370,19 +382,21 @@ jobs:
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -390,10 +404,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -410,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -427,7 +441,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
@@ -437,13 +451,13 @@ jobs:
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -474,10 +488,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -485,10 +501,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -505,7 +521,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -529,13 +545,13 @@ jobs:
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -566,10 +582,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -577,10 +595,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -597,7 +615,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -623,20 +641,22 @@ jobs:
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -644,10 +664,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
19
.github/workflows/golang-test-windows.yml
vendored
19
.github/workflows/golang-test-windows.yml
vendored
@@ -18,10 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -33,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -44,16 +46,15 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
destination: ${{ env.downloadPath }}\wintun.zip
|
||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
||||
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
|
||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -15,9 +15,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
@@ -38,13 +40,15 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -52,7 +56,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
|
||||
4
.github/workflows/install-script-test.yml
vendored
4
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,23 +16,25 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -52,9 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
|
||||
138
.github/workflows/proto-version-check.yml
vendored
138
.github/workflows/proto-version-check.yml
vendored
@@ -3,74 +3,92 @@ name: Proto Version Check
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- "**/*.proto"
|
||||
- "**/*.pb.go"
|
||||
- "**/generate.sh"
|
||||
- "proto-tools.env"
|
||||
- ".github/workflows/proto-version-check.yml"
|
||||
|
||||
jobs:
|
||||
regenerate-and-diff:
|
||||
name: Regenerate proto and verify no drift
|
||||
check-proto-versions:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Load pinned proto toolchain versions
|
||||
run: |
|
||||
# shellcheck source=/dev/null
|
||||
. ./proto-tools.env
|
||||
{
|
||||
echo "PROTOC_VERSION=${PROTOC_VERSION}"
|
||||
echo "PROTOC_GEN_GO_VERSION=${PROTOC_GEN_GO_VERSION}"
|
||||
echo "PROTOC_GEN_GO_GRPC_VERSION=${PROTOC_GEN_GO_GRPC_VERSION}"
|
||||
} >> "$GITHUB_ENV"
|
||||
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: context.issue.number,
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
- name: Setup protoc
|
||||
uses: arduino/setup-protoc@f4d5893b897028ff5739576ea0409746887fa536 # v3.0.0
|
||||
with:
|
||||
version: ${{ env.PROTOC_VERSION }}
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
const modifiedPbFiles = files.filter(
|
||||
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||
);
|
||||
if (modifiedPbFiles.length === 0) {
|
||||
console.log('No modified .pb.go files to check');
|
||||
return;
|
||||
}
|
||||
|
||||
- name: Install protoc plugins
|
||||
run: |
|
||||
go install "google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION}"
|
||||
go install "google.golang.org/grpc/cmd/protoc-gen-go-grpc@${PROTOC_GEN_GO_GRPC_VERSION}"
|
||||
echo "$(go env GOPATH)/bin" >> "$GITHUB_PATH"
|
||||
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const baseSha = context.payload.pull_request.base.sha;
|
||||
const headSha = context.payload.pull_request.head.sha;
|
||||
|
||||
- name: Verify protoc version matches pin
|
||||
run: |
|
||||
actual=$(protoc --version | awk '{print $2}')
|
||||
if [[ "$actual" != "$PROTOC_VERSION" ]]; then
|
||||
echo "::error::protoc $actual does not match pinned $PROTOC_VERSION"
|
||||
exit 1
|
||||
fi
|
||||
async function getVersionHeader(path, ref) {
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
path,
|
||||
ref,
|
||||
});
|
||||
if (!res.data.content) {
|
||||
return { ok: false, reason: 'no inline content (file too large)' };
|
||||
}
|
||||
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
|
||||
const lines = content
|
||||
.split('\n')
|
||||
.slice(0, 20)
|
||||
.filter(line => versionPattern.test(line));
|
||||
return { ok: true, lines };
|
||||
} catch (e) {
|
||||
return { ok: false, reason: e.message };
|
||||
}
|
||||
}
|
||||
|
||||
- name: Regenerate all proto bindings
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for script in \
|
||||
client/proto/generate.sh \
|
||||
shared/signal/proto/generate.sh \
|
||||
shared/management/proto/generate.sh \
|
||||
flow/proto/generate.sh \
|
||||
encryption/testprotos/generate.sh; do
|
||||
echo "::group::$script"
|
||||
bash "$script"
|
||||
echo "::endgroup::"
|
||||
done
|
||||
const violations = [];
|
||||
for (const file of modifiedPbFiles) {
|
||||
const [base, head] = await Promise.all([
|
||||
getVersionHeader(file.filename, baseSha),
|
||||
getVersionHeader(file.filename, headSha),
|
||||
]);
|
||||
if (!base.ok || !head.ok) {
|
||||
core.warning(
|
||||
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
base: base.lines,
|
||||
head: head.lines,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
- name: Fail if regeneration changed any tracked or untracked file
|
||||
run: |
|
||||
if [[ -n "$(git status --porcelain --untracked-files=all)" ]]; then
|
||||
echo "::error::Generated proto files drift from .proto sources or pinned tool versions."
|
||||
echo "Run the generate.sh scripts locally with the toolchain in proto-tools.env and commit the result."
|
||||
git status --short
|
||||
exit 1
|
||||
fi
|
||||
if (violations.length > 0) {
|
||||
const details = violations.map(v =>
|
||||
`${v.file}:\n` +
|
||||
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
|
||||
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
|
||||
).join('\n\n');
|
||||
|
||||
core.setFailed(
|
||||
`Proto version strings changed in generated files.\n` +
|
||||
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
|
||||
`Regenerate with the matching tool versions.\n\n` +
|
||||
details
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('No proto version string changes detected');
|
||||
|
||||
168
.github/workflows/release.yml
vendored
168
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
SIGN_PIPE_VER: "v0.1.5"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -24,7 +24,9 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
@@ -51,19 +53,26 @@ jobs:
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
pkg install -y git curl portlint
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -93,19 +102,19 @@ jobs:
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
@@ -124,26 +133,25 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -156,18 +164,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -191,7 +199,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -282,28 +290,28 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -314,27 +322,26 @@ jobs:
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -375,7 +382,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -404,7 +411,7 @@ jobs:
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -418,16 +425,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -441,7 +449,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -449,7 +457,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
@@ -474,27 +482,26 @@ jobs:
|
||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- name: Add 7-Zip to PATH
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
@@ -514,29 +521,27 @@ jobs:
|
||||
Get-ChildItem $workdir
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
destination: ${{ env.downloadPath }}\wintun.zip
|
||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
||||
|
||||
- name: Decompress wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
|
||||
destination: ${{ env.downloadPath }}\mesa3d.7z
|
||||
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -547,35 +552,38 @@ jobs:
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||
file-name: envar_plugin.zip
|
||||
location: ${{ github.workspace }}
|
||||
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
|
||||
destination: ${{ github.workspace }}\envar_plugin.zip
|
||||
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
|
||||
|
||||
- name: Extract EnVar plugin
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||
|
||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||
location: ${{ github.workspace }}
|
||||
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
|
||||
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
|
||||
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
|
||||
|
||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||
script-file: client/installer.nsis
|
||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||
shell: pwsh
|
||||
env:
|
||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||
run: |
|
||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
||||
Copy-Item -Destination $nsisPluginDir -Force
|
||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
||||
|
||||
- name: Rename NSIS installer
|
||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||
@@ -592,7 +600,7 @@ jobs:
|
||||
|
||||
- name: Upload installer artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-installer-test-${{ matrix.arch }}
|
||||
path: |
|
||||
@@ -611,7 +619,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
@@ -703,7 +711,7 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
|
||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
|
||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger android-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
@@ -42,10 +42,10 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger ios-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
repo: netbirdio/ios-client
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
|
||||
26
.github/workflows/test-infrastructure-files.yml
vendored
26
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
- "infrastructure_files/**"
|
||||
- ".github/workflows/test-infrastructure-files.yml"
|
||||
- "management/cmd/**"
|
||||
- "signal/cmd/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -68,15 +68,17 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -139,8 +141,8 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
@@ -254,7 +256,9 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
paths:
|
||||
- 'shared/management/http/api/openapi.yml'
|
||||
- "shared/management/http/api/openapi.yml"
|
||||
|
||||
jobs:
|
||||
trigger_docs_api_update:
|
||||
@@ -13,10 +13,10 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger API pages generation
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: generate api pages
|
||||
repo: netbirdio/docs
|
||||
ref: "refs/heads/main"
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
15
.github/workflows/wasm-build-validation.yml
vendored
15
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,15 +19,17 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
@@ -42,9 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
@@ -65,4 +69,3 @@ jobs:
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -84,6 +85,12 @@ type Options struct {
|
||||
DisableIPv6 bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// BlockLANAccess blocks the embedded peer from reaching the host's
|
||||
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
|
||||
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
|
||||
// when the embedded client must never act as a stepping stone into
|
||||
// the host's local network (e.g. the proxy's overlay peer).
|
||||
BlockLANAccess bool
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the tunnel interface.
|
||||
@@ -94,6 +101,26 @@ type Options struct {
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||
Performance Performance
|
||||
}
|
||||
|
||||
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||
//
|
||||
// These settings are process-global: any non-nil field also becomes the
|
||||
// default for Clients constructed by later embed.New calls in the same
|
||||
// process. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||
// leaves the pool unbounded. Lower values trade throughput for a
|
||||
// tighter memory ceiling. May also be changed on a running Client via
|
||||
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||
// writes per syscall, which also bounds eager buffer allocation per
|
||||
// worker. Zero uses the platform default. Applied at construction
|
||||
// only; ignored by Client.SetPerformance.
|
||||
MaxBatchSize *uint32
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
DisableIPv6: &opts.DisableIPv6,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
BlockLANAccess: &opts.BlockLANAccess,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
DNSLabels: parsedLabels,
|
||||
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||
}
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
}
|
||||
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
return state.PubKey, state.FQDN, true
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||
// takes effect, and only when it was nonzero at construction;
|
||||
// MaxBatchSize is construction-only and returns an error if set here.
|
||||
//
|
||||
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||
// running yet.
|
||||
func (c *Client) SetPerformance(t Performance) error {
|
||||
if t.MaxBatchSize != nil {
|
||||
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||
}
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetPerformance(internal.Performance{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
|
||||
@@ -21,7 +21,7 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
fm, err := uspfilter.Create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -29,47 +29,80 @@ const (
|
||||
NFTABLES
|
||||
)
|
||||
|
||||
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// SkipNftablesEnv is the environment variable to skip nftables check
|
||||
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
|
||||
|
||||
// errNoFirewallManager indicates no kernel firewall backend is present,
|
||||
// as opposed to a backend that exists but failed to create or initialize.
|
||||
var errNoFirewallManager = errors.New("no firewall manager found")
|
||||
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||
// We run in userspace mode and force userspace firewall was requested.
|
||||
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||
nativeFw, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding without it", err)
|
||||
}
|
||||
|
||||
log.Info("forcing userspace firewall")
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
return createUserspaceFirewall(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
}
|
||||
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
switch {
|
||||
case err == nil && !iface.IsUserspaceBind():
|
||||
// Nothing to do, fall through
|
||||
case err == nil && iface.IsUserspaceBind():
|
||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||
// needs a device filter for DNS interception hooks. Install a minimal
|
||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||
}
|
||||
case err != nil && !iface.IsUserspaceBind():
|
||||
// Kernel cannot fall back to anything else, need to return error
|
||||
return nil, err
|
||||
case err != nil && iface.IsUserspaceBind():
|
||||
// Fall back to the userspace packet filter if native is unavailable
|
||||
logNativeFirewallUnavailable(err)
|
||||
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||
// needs a device filter for DNS interception hooks. Install a minimal
|
||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// createUserspaceFirewall builds the userspace packet filter, optionally
|
||||
// backed by a native firewall, and allows netbird interface traffic.
|
||||
func createUserspaceFirewall(iface IFaceMapper, nativeFw firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := uspfilter.Create(iface, nativeFw, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// logNativeFirewallUnavailable logs the fallback to userspace at info level
|
||||
// when no kernel firewall backend exists, and at warn level otherwise.
|
||||
func logNativeFirewallUnavailable(err error) {
|
||||
if errors.Is(err, errNoFirewallManager) {
|
||||
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
|
||||
} else {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
return nil, fmt.Errorf("create firewall: %w", err)
|
||||
}
|
||||
|
||||
if err = fm.Init(stateManager); err != nil {
|
||||
@@ -88,29 +121,10 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
||||
log.Info("creating an nftables firewall manager")
|
||||
return nbnftables.Create(iface, mtu)
|
||||
default:
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
return nil, errNoFirewallManager
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||
}
|
||||
|
||||
if err := fm.AllowNetbird(); err != nil {
|
||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||
}
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||
func check() FWType {
|
||||
useIPTABLES := false
|
||||
@@ -132,35 +146,38 @@ func check() FWType {
|
||||
}
|
||||
}
|
||||
|
||||
nf := nftables.Conn{}
|
||||
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||
if !useIPTABLES {
|
||||
return NFTABLES
|
||||
}
|
||||
|
||||
// search for chains where table is filter
|
||||
// if we find one, we assume that nftables manager can be used with iptables
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name == "filter" {
|
||||
// Honor the skip env before probing nftables at all.
|
||||
if os.Getenv(SkipNftablesEnv) != "true" {
|
||||
nf := nftables.Conn{}
|
||||
if chains, err := nf.ListChains(); err == nil {
|
||||
if !useIPTABLES {
|
||||
return NFTABLES
|
||||
}
|
||||
}
|
||||
|
||||
// check tables for the following constraints:
|
||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||
// 4. if we find an error we log and continue with iptables check
|
||||
nbTablesList, err := nf.ListTables()
|
||||
switch {
|
||||
case err == nil && len(iptablesChains) > 0:
|
||||
return IPTABLES
|
||||
case err == nil && len(nbTablesList) != 1:
|
||||
return NFTABLES
|
||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||
return IPTABLES
|
||||
case err != nil:
|
||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||
// search for chains where table is filter
|
||||
// if we find one, we assume that nftables manager can be used with iptables
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name == "filter" {
|
||||
return NFTABLES
|
||||
}
|
||||
}
|
||||
|
||||
// check tables for the following constraints:
|
||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||
// 4. if we find an error we log and continue with iptables check
|
||||
nbTablesList, err := nf.ListTables()
|
||||
switch {
|
||||
case err == nil && len(iptablesChains) > 0:
|
||||
return IPTABLES
|
||||
case err == nil && len(nbTablesList) != 1:
|
||||
return NFTABLES
|
||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||
return IPTABLES
|
||||
case err != nil:
|
||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,554 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||
// external DNAT from bypassing ACL rules.
|
||||
mangleFwdKey = "MANGLE-FORWARD"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
return &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
m.stateManager = stateManager
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
if err := m.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||
if m.v6 && ipsetName != "" {
|
||||
ipsetName += "-v6"
|
||||
}
|
||||
proto := protoForFamily(protocol, m.v6)
|
||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
ipList.addIP(ip.String())
|
||||
return []firewall.Rule{&Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
v6: m.v6,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
}
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
if action == firewall.ActionDrop {
|
||||
// Insert at the beginning of the chain (position 1)
|
||||
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
|
||||
} else {
|
||||
err = m.iptablesClient.Append(tableFilter, chain, specs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
v6: m.v6,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ipsetList.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
shouldDestroyIpset = true
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
|
||||
if r.mangleSpecs != nil {
|
||||
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["INPUT"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries["FORWARD"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) createDefaultChains() error {
|
||||
// chain netbird-acl-input-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chainName == mangleFwdKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chainName, entries := range m.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||
// We want to make sure our traffic is not dropped by existing rules.
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
// Inbound is handled by our ACLs, the rest is dropped.
|
||||
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||
// ACL mark check where it cannot be overridden.
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
func (m *aclManager) updateState() {
|
||||
if m.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
|
||||
if ipsetName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
actionSuffix := ""
|
||||
if action == firewall.ActionDrop {
|
||||
actionSuffix = "-drop"
|
||||
}
|
||||
|
||||
switch {
|
||||
case sPort != nil && dPort != nil:
|
||||
return ipsetName + "-sport-dport" + actionSuffix
|
||||
case sPort != nil:
|
||||
return ipsetName + "-sport" + actionSuffix
|
||||
case dPort != nil:
|
||||
return ipsetName + "-dport" + actionSuffix
|
||||
default:
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if m.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
352
client/firewall/iptables/chains_linux.go
Normal file
352
client/firewall/iptables/chains_linux.go
Normal file
@@ -0,0 +1,352 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) createContainers() error {
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFwdIn, tableFilter},
|
||||
{chainRTFwdOut, tableFilter},
|
||||
{chainRTPre, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRdr, tableNat},
|
||||
{chainRTMSSClamp, tableMangle},
|
||||
} {
|
||||
// Fallback: clear chains that survived an unclean shutdown.
|
||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add static nat rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addJumpRules() error {
|
||||
// Jump to nat chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat postrouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNATPost] = natRule
|
||||
|
||||
// Jump to mangle prerouting chain
|
||||
preRule := []string{"-j", chainRTPre}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpManglePre] = preRule
|
||||
|
||||
// Jump to nat prerouting chain
|
||||
rdrRule := []string{"-j", chainRTRdr}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
|
||||
return fmt.Errorf("add nat prerouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNATPre] = rdrRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map. Rules are
|
||||
// inserted at position 1, so the order here is reversed.
|
||||
//
|
||||
// Existing FORWARD policy decides outbound traffic towards our
|
||||
// interface. If FORWARD policy is "drop", we add an
|
||||
// established/related rule to allow return traffic for inbound rules.
|
||||
func (r *family) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
|
||||
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
|
||||
|
||||
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
|
||||
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from
|
||||
// the wg interface, it traverses FORWARD instead of INPUT,
|
||||
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
|
||||
// inserted above ours. Mangle runs before filter, so these guard
|
||||
// rules enforce the ACL mark check where it cannot be overridden.
|
||||
r.appendToEntries(mangleForwardKey, []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
r.appendToEntries(mangleForwardKey, []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) seedInitialOptionalEntries() {
|
||||
r.optionalEntries[chainForward] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
|
||||
r.entries[chain] = append(r.entries[chain], spec)
|
||||
}
|
||||
|
||||
func (r *family) createDefaultChains() error {
|
||||
if err := r.iptablesClient.NewChain(tableName, chainACLInput); err != nil {
|
||||
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
|
||||
}
|
||||
|
||||
for chain, rules := range r.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chain == mangleForwardKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
|
||||
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chain, entries := range r.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
r.entries[chain] = append(r.entries[chain], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(r.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range r.entries[mangleForwardKey] {
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) cleanUpDefaultForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
|
||||
// the others, so the chain below deletes cleanly instead of failing
|
||||
// with "device or resource busy".
|
||||
if err := r.cleanJumpRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFwdIn, tableFilter},
|
||||
{chainRTFwdOut, tableFilter},
|
||||
{chainRTPre, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRdr, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSClamp, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
||||
continue
|
||||
}
|
||||
if ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanJumpRules() error {
|
||||
// locations maps each tracked jump rule to the built-in table and
|
||||
// chain it was inserted into.
|
||||
locations := map[firewall.RuleID]struct{ table, chain string }{
|
||||
jumpNATPost: {tableNat, chainPostrouting},
|
||||
jumpManglePre: {tableMangle, chainPrerouting},
|
||||
jumpNATPre: {tableNat, chainPrerouting},
|
||||
jumpMSSClamp: {tableMangle, chainForward},
|
||||
jumpNATOutput: {tableNat, chainOutput},
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for ruleID, loc := range locations {
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
|
||||
continue
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanAclChains() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.cleanInputAclChain(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanPreroutingEntries(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
for _, rule := range r.entries[mangleForwardKey] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanInputAclChain() error {
|
||||
ok, err := r.iptablesClient.ChainExists(tableName, chainACLInput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.entries[chainInput] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableName, chainInput, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range r.entries[chainForward] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableName, chainForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(tableName, chainACLInput); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanPreroutingEntries() error {
|
||||
ok, err := r.iptablesClient.ChainExists(tableMangle, chainPrerouting)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s in %s: %w", chainPrerouting, tableMangle, err)
|
||||
}
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, rule := range r.entries[chainPrerouting] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainPrerouting, rule, err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
269
client/firewall/iptables/dnat_linux.go
Normal file
269
client/firewall/iptables/dnat_linux.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
toDestination := rule.TranslatedAddress.String()
|
||||
switch {
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
// no translated port, use original port
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
// need the "/originalport" suffix to avoid dnat port randomization
|
||||
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
|
||||
proto := strings.ToLower(string(rule.Protocol))
|
||||
|
||||
rules := make(map[firewall.RuleID]ruleInfo, 3)
|
||||
|
||||
// DNAT rule
|
||||
dnatRule := []string{
|
||||
"!", "-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "DNAT",
|
||||
"--to-destination", toDestination,
|
||||
}
|
||||
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||
rules[ruleID+dnatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRdr,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
// SNAT rule
|
||||
snatRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleID+snatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTNAT,
|
||||
rule: snatRule,
|
||||
}
|
||||
|
||||
// Forward filtering rule, if fwd policy is DROP
|
||||
forwardRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "ACCEPT",
|
||||
}
|
||||
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleID+fwdSuffix] = ruleInfo{
|
||||
table: tableFilter,
|
||||
chain: chainRTFwdOut,
|
||||
rule: forwardRule,
|
||||
}
|
||||
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||
log.Errorf("rollback failed: %v", rollbackErr)
|
||||
}
|
||||
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||
}
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
|
||||
var merr *multierror.Error
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||
// On rollback error, add to rules map for next cleanup
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
r.updateState()
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
|
||||
var merr *multierror.Error
|
||||
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleID+dnatSuffix)
|
||||
}
|
||||
|
||||
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleID+snatSuffix)
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleID+fwdSuffix)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
}
|
||||
|
||||
info := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRdr,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = info.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *family) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNATOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
246
client/firewall/iptables/family_linux.go
Normal file
246
client/firewall/iptables/family_linux.go
Normal file
@@ -0,0 +1,246 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableName = tableFilter
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
|
||||
// chainACLInput is the peer ACL chain that holds installed
|
||||
// peer-filtering rules.
|
||||
chainACLInput = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleForwardKey is the entries map key for mangle FORWARD guard
|
||||
// rules that prevent external DNAT from bypassing ACL rules.
|
||||
mangleForwardKey chainKey = "MANGLE-FORWARD"
|
||||
|
||||
chainInput = "INPUT"
|
||||
chainPostrouting = "POSTROUTING"
|
||||
chainPrerouting = "PREROUTING"
|
||||
chainForward = "FORWARD"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPre = "NETBIRD-RT-PRE"
|
||||
chainRTRdr = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNATPre = "jump-nat-pre"
|
||||
jumpNATPost = "jump-nat-post"
|
||||
jumpNATOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix firewall.RuleID = "_dnat"
|
||||
snatSuffix firewall.RuleID = "_snat"
|
||||
fwdSuffix firewall.RuleID = "_fwd"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
chain string
|
||||
table string
|
||||
rule []string
|
||||
}
|
||||
|
||||
type routeRules map[firewall.RuleID][]string
|
||||
|
||||
// ruleSpec is a single iptables rule expressed as its argument list
|
||||
// (e.g. {"-i", "wg0", "-j", "DROP"}).
|
||||
type ruleSpec []string
|
||||
|
||||
// chainKey identifies the chain a seeded entry belongs to. It holds
|
||||
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
|
||||
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
|
||||
type chainKey string
|
||||
|
||||
// aclEntries maps a chain to the rules seeded into it to jump into or
|
||||
// guard the netbird ACL chains.
|
||||
type aclEntries map[chainKey][]ruleSpec
|
||||
|
||||
type entry struct {
|
||||
spec ruleSpec
|
||||
position int
|
||||
}
|
||||
|
||||
// ipsetCounter is the shared hash:net refcounter used by peer and
|
||||
// route ACLs alike. The ipset library does not support comments, so
|
||||
// the key is just the set name (string).
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
// family holds the per-address-family iptables state. One instance
|
||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||
// single family; the top-level Manager owns one for v4 and another
|
||||
// for v6.
|
||||
type family struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
v6 bool
|
||||
|
||||
// Peer ACL chain bookkeeping.
|
||||
entries aclEntries
|
||||
optionalEntries map[chainKey][]entry
|
||||
|
||||
// filters holds peer + route filter rules keyed by content hash.
|
||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||
filters map[nbid.RuleID]*Rule
|
||||
ipsetCounter *ipsetCounter
|
||||
|
||||
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
|
||||
// plumbing that isn't a filter rule).
|
||||
rules routeRules
|
||||
|
||||
// Routing / NAT.
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||
r := &family{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
entries: make(aclEntries),
|
||||
optionalEntries: make(map[chainKey][]entry),
|
||||
filters: make(map[nbid.RuleID]*Rule),
|
||||
rules: make(routeRules),
|
||||
mtu: mtu,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||
return struct{}{}, r.createIpSet(name, sources)
|
||||
},
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// init wires the family to the state manager and installs both the
|
||||
// route ACL containers and the peer ACL chain skeleton.
|
||||
func (r *family) init(stateManager *statemanager.Manager) error {
|
||||
r.stateManager = stateManager
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.seedInitialEntries()
|
||||
r.seedInitialOptionalEntries()
|
||||
|
||||
if err := r.cleanAclChains(); err != nil {
|
||||
return fmt.Errorf("clean acl chains: %w", err)
|
||||
}
|
||||
if err := r.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset tears down all firewall state owned by this family. ACL
|
||||
// chain cleanup runs before route-chain cleanup because the route
|
||||
// chains are still referenced by FORWARD jumps installed during
|
||||
// seedInitialEntries; deleting them first would trip EBUSY.
|
||||
func (r *family) Reset() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.cleanAclChains(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
clear(r.rules)
|
||||
clear(r.filters)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) updateState() {
|
||||
if r.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
currentState.ACLEntries6 = r.entries
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
currentState.ACLEntries = r.entries
|
||||
}
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
334
client/firewall/iptables/filter_linux.go
Normal file
334
client/firewall/iptables/filter_linux.go
Normal file
@@ -0,0 +1,334 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// AddFilterRule installs a packet-filtering rule. With destination
|
||||
// empty, the rule goes to the peer ACL input chain plus a paired
|
||||
// mangle PREROUTING rule for the redirect mark. With destination set
|
||||
// (prefix or named set), it goes to the route ACL forward chain.
|
||||
// Multi-source rules collapse to one iptables rule via the shared
|
||||
// hash:net ipset.
|
||||
func (r *family) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
if existing, ok := r.filters[ruleID]; ok {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source match: %w", err)
|
||||
}
|
||||
|
||||
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
r.dropSourceMatch(srcMatch)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.filters[ruleID] = rule
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *family) hasRule(id nbid.RuleID) bool {
|
||||
_, ok := r.filters[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasDNATRule reports whether this family owns the DNAT rule set for
|
||||
// the given user id. DNAT rules live in r.rules under the well-known
|
||||
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
|
||||
// to pick the right family.
|
||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteFilterRule removes a previously installed filter rule. The
|
||||
// rule's stored chain/table identify where to delete from; source set
|
||||
// references are recovered from the spec via findSets and dropped
|
||||
// from the shared ipset counter.
|
||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
pr, ok := r.filters[ruleID]
|
||||
if !ok {
|
||||
log.Debugf("filter rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, pr.chain, pr.specs...); err != nil {
|
||||
return fmt.Errorf("delete rule %s: %w", pr.chain, err)
|
||||
}
|
||||
if pr.mangleSpecs != nil {
|
||||
if err := r.iptablesClient.Delete(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
|
||||
log.Errorf("delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.dropSourceMatch(pr.specs)
|
||||
delete(r.filters, ruleID)
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// findSets scans an iptables rule spec for "-m set --match-set <name>
|
||||
// <dir>" fragments and returns the named sets in occurrence order.
|
||||
// Used at delete time to drop ipsetCounter references.
|
||||
func findSets(rule []string) []string {
|
||||
var sets []string
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
sets = append(sets, rule[i+3])
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||
// single prefix inline, or an ipset for multiple sources.
|
||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||
switch {
|
||||
case len(sources) == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1:
|
||||
return firewall.Network{Prefix: sources[0]}
|
||||
default:
|
||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||
}
|
||||
}
|
||||
|
||||
// applySourceMatch returns the iptables match fragment for the rule's
|
||||
// source. For a Set it increments the shared ipset's refcount; for a
|
||||
// Prefix it emits a direct -s match; for the wildcard it returns nil.
|
||||
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
switch {
|
||||
case network.IsSet():
|
||||
if r.ipsetCounter == nil {
|
||||
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
|
||||
}
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
|
||||
}
|
||||
return []string{"-m", "set", matchSet, name, "src"}, nil
|
||||
case network.IsPrefix():
|
||||
return []string{"-s", network.Prefix.String()}, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
|
||||
// call when the spec is empty or holds only inline matchers. Decrement
|
||||
// errors are logged but not returned: the filter rule has already been
|
||||
// deleted at that point and we don't want to leak the deletion.
|
||||
func (r *family) dropSourceMatch(srcMatch []string) {
|
||||
if r.ipsetCounter == nil {
|
||||
return
|
||||
}
|
||||
for _, name := range findSets(srcMatch) {
|
||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||
log.Errorf("rollback ipset decrement %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decrementSetCounter drops ipset references owned by a raw rule spec
|
||||
// stored in r.rules (NAT / legacy route entries). It returns an error
|
||||
// aggregate so the caller surfaces decrement failures.
|
||||
func (r *family) decrementSetCounter(rule []string) error {
|
||||
if r.ipsetCounter == nil {
|
||||
return nil
|
||||
}
|
||||
var merr *multierror.Error
|
||||
for _, name := range findSets(rule) {
|
||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// installFilterRule assembles and writes one iptables filter-chain
|
||||
// rule. With destination empty the rule lands in the peer ACL input
|
||||
// chain and a paired mangle PREROUTING rule is added for the redirect
|
||||
// mark. With destination set the rule lands in the route ACL forward
|
||||
// chain and there is no mangle pairing.
|
||||
func (r *family) installFilterRule(
|
||||
ruleID nbid.RuleID,
|
||||
srcMatch []string,
|
||||
destination firewall.Network,
|
||||
protocol firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (*Rule, error) {
|
||||
isRoute := !destination.IsZero()
|
||||
|
||||
proto := protoForFamily(protocol, r.v6)
|
||||
|
||||
specs := slices.Clone(srcMatch)
|
||||
var destExp []string
|
||||
if isRoute {
|
||||
var err error
|
||||
destExp, err = r.applyNetwork("-d", destination, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
specs = append(specs, destExp...)
|
||||
}
|
||||
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
|
||||
|
||||
var mangleSpecs []string
|
||||
if !isRoute {
|
||||
mangleSpecs = slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
}
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
|
||||
chain := chainACLInput
|
||||
if isRoute {
|
||||
chain = chainRTFwdIn
|
||||
}
|
||||
|
||||
// Peer ACL drops are inserted at position 1 so they precede the
|
||||
// chain's catch-all; route ACL drops are inserted at position 2
|
||||
// to sit immediately after the established/related accept rule.
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
pos := 1
|
||||
if isRoute {
|
||||
pos = 2
|
||||
}
|
||||
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chain, specs...)
|
||||
}
|
||||
if err != nil {
|
||||
r.dropSourceMatch(destExp)
|
||||
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
|
||||
}
|
||||
|
||||
// The mangle redirect-mark rule is best effort: the filter rule itself
|
||||
// is what enforces the ACL, so a mangle failure must not undo it. Drop
|
||||
// the spec so teardown does not try to remove a rule that was not added.
|
||||
if mangleSpecs != nil {
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
|
||||
log.Errorf("add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Rule{
|
||||
id: ruleID,
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
chain: chain,
|
||||
v6: r.v6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// applyNetwork resolves a firewall.Network into the iptables match
|
||||
// fragment for the given direction flag (-s or -d). Set networks
|
||||
// increment the shared ipset refcount; prefixes emit a direct match;
|
||||
// an empty network returns no spec ("match any").
|
||||
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
direction := "src"
|
||||
if flag == "-d" {
|
||||
direction = "dst"
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
// filterMatchSpecs returns the proto/port match fragment for a
|
||||
// filtering rule. The source match (-s or -m set) is built by the
|
||||
// caller and prepended.
|
||||
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||
}
|
||||
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(int(p))
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
104
client/firewall/iptables/ipset_linux.go
Normal file
104
client/firewall/iptables/ipset_linux.go
Normal file
@@ -0,0 +1,104 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
// The refcounter records nothing when this callback errors,
|
||||
// so destroy the set or it leaks in the kernel. A partial
|
||||
// source set would also fail-open for deny rules, so the
|
||||
// rule must fail rather than install with a missing source.
|
||||
if derr := r.destroyIPSet(setName); derr != nil {
|
||||
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
||||
}
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) deleteIpSet(setName string) error {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
log.Debugf("deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
name := r.ipsetName(set.HashedName())
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) ipsetName(name string) string {
|
||||
if r.v6 {
|
||||
return name + "-v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (r *family) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if r.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -18,25 +17,21 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type resetter interface {
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
// Manager of iptables firewall. Per-family state (peer ACLs, route
|
||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||
// by family and provides the public firewall.Manager surface.
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
family4 *family
|
||||
rawSupported bool
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
ipv6Client *iptables.IPTables
|
||||
aclMgr6 *aclManager
|
||||
router6 *router
|
||||
family6 *family
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -57,14 +52,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
return nil, fmt.Errorf("create family: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
@@ -81,21 +71,18 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("init ip6tables: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
|
||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||
family6, err := newFamily(ip6Client, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
return fmt.Errorf("create v6 family: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// Share the same IP forwarding state with the v4 family, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
family6.ipFwdState = m.family4.ipFwdState
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
m.family6 = family6
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -109,7 +96,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
MTU: m.family4.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -141,31 +128,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChains initializes router and ACL chains for both address families,
|
||||
// rolling back on failure.
|
||||
// initChains initializes the per-family firewall state for both
|
||||
// address families, rolling back on failure.
|
||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
type initStep struct {
|
||||
name string
|
||||
init func(*statemanager.Manager) error
|
||||
mgr resetter
|
||||
r *family
|
||||
}
|
||||
|
||||
steps := []initStep{
|
||||
{"router", m.router.init, m.router},
|
||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||
}
|
||||
steps := []initStep{{"v4", m.family4}}
|
||||
if m.hasIPv6() {
|
||||
steps = append(steps,
|
||||
initStep{"v6 router", m.router6.init, m.router6},
|
||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||
)
|
||||
steps = append(steps, initStep{"v6", m.family6})
|
||||
}
|
||||
|
||||
var initialized []initStep
|
||||
for _, s := range steps {
|
||||
if err := s.init(stateManager); err != nil {
|
||||
if err := s.r.init(stateManager); err != nil {
|
||||
for i := len(initialized) - 1; i >= 0; i-- {
|
||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||
if rerr := initialized[i].r.Reset(); rerr != nil {
|
||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||
}
|
||||
}
|
||||
@@ -176,84 +156,50 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
|
||||
// docs for destination semantics. Sources are a single address family;
|
||||
// the rule is dispatched to the matching v4 / v6 backend.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
fam := m.family4
|
||||
if isIPv6Rule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
fam = m.family6
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteFilterRule removes a rule previously added via AddFilterRule.
|
||||
// The rule is looked up by id in each family's filter cache.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||
return m.aclMgr6.DeletePeerRule(rule)
|
||||
id := rule.ID()
|
||||
if m.family4.hasRule(id) {
|
||||
return m.family4.DeleteFilterRule(rule)
|
||||
}
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.v6
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
if m.hasIPv6() && m.family6.hasRule(id) {
|
||||
return m.family6.DeleteFilterRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
log.Debugf("filter rule %s not found in any family", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -272,10 +218,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
return m.family6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
if err := m.family4.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -284,7 +230,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -300,18 +246,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
return m.family6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -320,11 +266,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -341,19 +287,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclMgr6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||
}
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
@@ -377,11 +317,11 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
var merr *multierror.Error
|
||||
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -402,14 +342,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -424,9 +364,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
return m.family6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
return m.family4.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
@@ -434,10 +374,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
|
||||
return m.family6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
return m.family4.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
@@ -454,12 +394,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -476,9 +416,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
@@ -490,9 +430,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
@@ -504,9 +444,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
@@ -518,14 +458,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
const (
|
||||
chainNameRaw = "NETBIRD-RAW"
|
||||
chainOUTPUT = "OUTPUT"
|
||||
chainOutput = "OUTPUT"
|
||||
tableRaw = "raw"
|
||||
)
|
||||
|
||||
@@ -600,15 +540,15 @@ func (m *Manager) initNoTrackChain() error {
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
log.Debugf("delete orphan chain: %v", delErr)
|
||||
}
|
||||
return fmt.Errorf("add output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
|
||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
|
||||
log.Debugf("delete output jump rule: %v", delErr)
|
||||
}
|
||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||
@@ -635,11 +575,11 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
|
||||
jumpRule := []string{"-j", chainNameRaw}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove output jump rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
|
||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
@@ -654,3 +594,13 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
|
||||
// the destination prefix when set, otherwise from the (single-family)
|
||||
// sources.
|
||||
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
@@ -65,46 +67,39 @@ func TestIptablesManager(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
var rule2 fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
rr := r.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
}
|
||||
rr := rule2.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
require.NoError(t, err, "failed to reset")
|
||||
|
||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
|
||||
require.NoError(t, err, "failed check chain exists")
|
||||
|
||||
if ok {
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -126,15 +121,13 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{22}}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
|
||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
require.NotEmpty(t, rule, "deny rule should not be empty")
|
||||
require.NotNil(t, rule, "deny rule should not be nil")
|
||||
|
||||
// Verify the rule was added by checking iptables
|
||||
for _, r := range rule {
|
||||
rr := r.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
}
|
||||
rr := rule.(*Rule)
|
||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||
})
|
||||
|
||||
t.Run("deny rule precedence test", func(t *testing.T) {
|
||||
@@ -142,36 +135,40 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
|
||||
// Add accept rule first
|
||||
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
|
||||
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add accept rule")
|
||||
|
||||
// Add deny rule second for same IP/port - this should take precedence
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
|
||||
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
||||
rules, err := ipv4Client.List("filter", chainNameInputRules)
|
||||
rules, err := ipv4Client.List("filter", chainACLInput)
|
||||
require.NoError(t, err, "failed to list iptables rules")
|
||||
|
||||
// Debug: print all rules
|
||||
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
|
||||
t.Logf("All iptables rules in chain %s:", chainACLInput)
|
||||
for i, rule := range rules {
|
||||
t.Logf(" [%d] %s", i, rule)
|
||||
}
|
||||
|
||||
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
|
||||
// match. Match on that shape instead of the legacy
|
||||
// per-(action,port) ipset names ("deny-http"/"accept-http")
|
||||
// that this test predates.
|
||||
srcMatch := fmt.Sprintf("-s %s/32", ip)
|
||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
if strings.Contains(rule, "DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
|
||||
denyRuleIndex = i
|
||||
}
|
||||
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(rule, "ACCEPT") {
|
||||
if strings.Contains(rule, "-j DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
denyRuleIndex = i
|
||||
}
|
||||
if strings.Contains(rule, "-j ACCEPT") {
|
||||
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
||||
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
|
||||
acceptRuleIndex = i
|
||||
}
|
||||
acceptRuleIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,7 +193,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -210,27 +206,39 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
time.Sleep(time.Second)
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
var rule2 fw.Rule
|
||||
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
}
|
||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.NotNil(t, rule2)
|
||||
require.Contains(t, rule2.(*Rule).specs, "-s",
|
||||
"single-source rule should use direct -s match, not an ipset")
|
||||
require.Empty(t, findSets(rule2.(*Rule).specs),
|
||||
"single-source rule should not allocate a shared ipset")
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
t.Run("delete single-source rule", func(t *testing.T) {
|
||||
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
||||
})
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
t.Run("multi-source uses shared ipset", func(t *testing.T) {
|
||||
sources := []netip.Prefix{
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{8080}}
|
||||
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add multi-source rule")
|
||||
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
|
||||
sets := findSets(multi.(*Rule).specs)
|
||||
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
|
||||
|
||||
require.NoError(t, manager.DeleteFilterRule(multi))
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -281,7 +289,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 11. MSS clamping rule for outbound traffic
|
||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
|
||||
require.True(t, exists, "postrouting jump rule should exist")
|
||||
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
|
||||
require.True(t, exists, "prerouting jump rule should exist")
|
||||
|
||||
pair := firewall.RouterPair{
|
||||
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
err = manager.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "marking rule should be inserted")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "marking rule should be created")
|
||||
foundRule, found := manager.rules[natRuleKey]
|
||||
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
// Check inverse rule
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
if testCase.InputPair.Masquerade {
|
||||
require.True(t, exists, "inverse marking rule should be created")
|
||||
foundRule, found := manager.rules[inverseRuleKey]
|
||||
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
err = manager.RemoveNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
||||
markingRule := []string{
|
||||
"-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
}
|
||||
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
require.False(t, exists, "marking rule should not exist")
|
||||
|
||||
_, found := manager.rules[natRuleKey]
|
||||
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
// Check inverse rule removal
|
||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
||||
inverseMarkingRule := []string{
|
||||
"!", "-i", ifaceMock.Name(),
|
||||
"-m", "conntrack",
|
||||
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
}
|
||||
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
||||
require.False(t, exists, "inverse marking rule should not exist")
|
||||
|
||||
_, found = manager.rules[inverseRuleKey]
|
||||
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
require.NoError(t, err, "Failed to reset router")
|
||||
require.NoError(t, err, "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -334,62 +334,30 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddFilterRule failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
stored, ok := r.filters[ruleKey.ID()]
|
||||
require.True(t, ok, "rule not stored in filters")
|
||||
t.Logf("Internal rule: %v", stored.specs)
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
var source firewall.Network
|
||||
if len(tt.sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||
} else if len(tt.sources) > 0 {
|
||||
source.Prefix = tt.sources[0]
|
||||
}
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Source: source,
|
||||
Destination: firewall.Network{Prefix: tt.destination},
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
}
|
||||
|
||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
assert.True(t, exists, "IPSet not created")
|
||||
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSetNameInRule(t *testing.T) {
|
||||
r := &router{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule []string
|
||||
@@ -430,7 +398,7 @@ func TestFindSetNameInRule(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := r.findSets(tc.rule)
|
||||
result := findSets(tc.rule)
|
||||
|
||||
if len(result) != len(tc.expected) {
|
||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||
|
||||
269
client/firewall/iptables/routing_linux.go
Normal file
269
client/firewall/iptables/routing_linux.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the current legacy management mode
|
||||
func (r *family) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
"!", "-o", "lo",
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||
return fmt.Errorf("add outbound masquerade rule: %w", err)
|
||||
}
|
||||
r.rules["static-nat-outbound"] = rule1
|
||||
|
||||
// Second rule for return traffic masquerade
|
||||
rule2 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
"-o", r.wgIface.Name(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||
return fmt.Errorf("add return masquerade rule: %w", err)
|
||||
}
|
||||
r.rules["static-nat-return"] = rule2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
func (r *family) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.v6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
"-j", chainRTMSSClamp,
|
||||
}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
|
||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||
}
|
||||
r.rules[jumpMSSClamp] = jumpRule
|
||||
|
||||
ruleOut := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mss),
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
|
||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||
}
|
||||
r.rules["mss-clamp-out"] = ruleOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID("established-" + chain)
|
||||
r.rules[ruleID] = establishedRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.NatFormat)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
||||
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
rule := []string{"-i", r.wgIface.Name()}
|
||||
if pair.Inverse {
|
||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||
}
|
||||
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
)
|
||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -s: %w", err)
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule,
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
|
||||
r.dropSourceMatch(rule)
|
||||
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.NatFormat)
|
||||
|
||||
if rule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
||||
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
@@ -1,18 +1,20 @@
|
||||
package iptables
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
ruleID string
|
||||
ipsetName string
|
||||
import "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
// Rule to handle management of rules. Source set membership (when the
|
||||
// rule was built against a shared hash:net ipset) is encoded in specs;
|
||||
// DeleteFilterRule recovers it via findSets so the refcounter can drop
|
||||
// the right reference.
|
||||
type Rule struct {
|
||||
id manager.RuleID
|
||||
specs []string
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
// ID returns the rule id
|
||||
func (r *Rule) ID() manager.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) *ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return &ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{
|
||||
IPs: s.ips,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
|
||||
if temp.IPs == nil {
|
||||
temp.IPs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]*ipList
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]*ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipsetNames() []string {
|
||||
names := make([]string, 0, len(s.ipsets))
|
||||
for name := range s.ipsets {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{
|
||||
IPSets: s.ipsets,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
|
||||
if temp.IPSets == nil {
|
||||
temp.IPSets = make(map[string]*ipList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -29,17 +29,13 @@ type ShutdownState struct {
|
||||
|
||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
|
||||
// IPv6 counterparts
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -57,17 +53,14 @@ func (s *ShutdownState) Cleanup() error {
|
||||
}
|
||||
|
||||
if s.RouteRules != nil {
|
||||
ipt.router.rules = s.RouteRules
|
||||
ipt.family4.rules = s.RouteRules
|
||||
}
|
||||
if s.RouteIPsetCounter != nil {
|
||||
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
}
|
||||
|
||||
if s.ACLEntries != nil {
|
||||
ipt.aclMgr.entries = s.ACLEntries
|
||||
}
|
||||
if s.ACLIPsetStore != nil {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
ipt.family4.entries = s.ACLEntries
|
||||
}
|
||||
|
||||
// Clean up v6 state even if the current run has no IPv6.
|
||||
@@ -79,16 +72,13 @@ func (s *ShutdownState) Cleanup() error {
|
||||
}
|
||||
if ipt.hasIPv6() {
|
||||
if s.RouteRules6 != nil {
|
||||
ipt.router6.rules = s.RouteRules6
|
||||
ipt.family6.rules = s.RouteRules6
|
||||
}
|
||||
if s.RouteIPsetCounter6 != nil {
|
||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
}
|
||||
if s.ACLEntries6 != nil {
|
||||
ipt.aclMgr6.entries = s.ACLEntries6
|
||||
}
|
||||
if s.ACLIPsetStore6 != nil {
|
||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||
ipt.family6.entries = s.ACLEntries6
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
27
client/firewall/iptables/testhelpers_linux_test.go
Normal file
27
client/firewall/iptables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
||||
}
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package manager
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
@@ -16,6 +15,12 @@ import (
|
||||
// method but the IPv6 firewall components were not initialized.
|
||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
||||
|
||||
// ErrNoSources is returned when AddFilterRule is called with an empty
|
||||
// source list. "Match any source" must be expressed explicitly with a
|
||||
// /0 prefix; an empty list is a caller error and is rejected rather
|
||||
// than silently widening the rule to every source.
|
||||
var ErrNoSources = errors.New("rule has no sources")
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
@@ -23,13 +28,18 @@ const (
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
// RuleID identifies a firewall rule. It is a typed string so the
|
||||
// compiler catches accidental mixing with arbitrary string keys. It is
|
||||
// only an identifier and does not implement Rule.
|
||||
type RuleID string
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// ID returns the rule id
|
||||
ID() string
|
||||
ID() RuleID
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
@@ -91,6 +101,13 @@ func (d Network) IsPrefix() bool {
|
||||
return d.Prefix.IsValid()
|
||||
}
|
||||
|
||||
// IsZero returns true if the network designates no destination, i.e. it
|
||||
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
|
||||
// one carries a prefix or set destination.
|
||||
func (d Network) IsZero() bool {
|
||||
return !d.IsPrefix() && !d.IsSet()
|
||||
}
|
||||
|
||||
// Manager is the high level abstraction of a firewall manager
|
||||
//
|
||||
// It declares methods which handle actions required by the
|
||||
@@ -101,43 +118,42 @@ type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFilterRule adds a packet-filtering rule to the firewall.
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
// If destination is the zero Network, the rule applies to traffic
|
||||
// inbound to this node, i.e. peer ACL semantics, installed in
|
||||
// the kernel's input chain. If destination is set (prefix or
|
||||
// set), the rule applies to forwarded traffic with that
|
||||
// destination, route ACL semantics, installed in the forward
|
||||
// chain.
|
||||
//
|
||||
// Note: Callers should call Flush() after adding rules to ensure
|
||||
// they are applied to the kernel and rule handles are refreshed.
|
||||
AddPeerFiltering(
|
||||
// sources must be a single address family; the caller splits mixed
|
||||
// families and calls once per family. "Match any source" must be
|
||||
// expressed with an explicit /0 prefix; an empty sources list is
|
||||
// rejected with ErrNoSources so a zeroed list can never widen a
|
||||
// rule to every source.
|
||||
//
|
||||
// Note: callers should call Flush() after adding rules.
|
||||
AddFilterRule(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
) ([]Rule, error)
|
||||
) (Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
// DeleteFilterRule removes a filtering rule previously added via
|
||||
// AddFilterRule. The rule's own type identifies whether it lives
|
||||
// in the peer (input) or route (forward) path.
|
||||
DeleteFilterRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort, dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
@@ -185,8 +201,9 @@ type Manager interface {
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
// GenKey builds the rule id for this pair from the given format.
|
||||
func (p RouterPair) GenKey(format string) RuleID {
|
||||
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
@@ -242,6 +259,20 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
return merged
|
||||
}
|
||||
|
||||
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
|
||||
// plain v4 form, shifting the prefix length out of the 96-bit mapped
|
||||
// range. Other prefixes are returned unchanged. Keeping prefixes
|
||||
// unmapped ensures v4 rules match consistently and the match builders
|
||||
// read the correct address length.
|
||||
func UnmapPrefix(p netip.Prefix) netip.Prefix {
|
||||
addr := p.Addr()
|
||||
if !addr.Is4In6() {
|
||||
return p
|
||||
}
|
||||
bits := max(p.Bits()-96, 0)
|
||||
return netip.PrefixFrom(addr.Unmap(), bits)
|
||||
}
|
||||
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
|
||||
@@ -13,13 +13,13 @@ type ForwardRule struct {
|
||||
TranslatedPort Port
|
||||
}
|
||||
|
||||
func (r ForwardRule) ID() string {
|
||||
func (r ForwardRule) ID() RuleID {
|
||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||
r.Protocol,
|
||||
r.DestinationPort.String(),
|
||||
r.TranslatedAddress.String(),
|
||||
r.TranslatedPort.String())
|
||||
return id
|
||||
return RuleID(id)
|
||||
}
|
||||
|
||||
func (r ForwardRule) String() string {
|
||||
|
||||
@@ -40,7 +40,7 @@ func (h Set) Comment() string {
|
||||
|
||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||
// sort for consistent naming
|
||||
prefixes = slices.Clone(prefixes)
|
||||
SortPrefixes(prefixes)
|
||||
|
||||
hash := sha256.New()
|
||||
|
||||
@@ -1,713 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
af addrFamily
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
}
|
||||
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||
}
|
||||
|
||||
return &AclManager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
m.workTable = workTable
|
||||
return m.createDefaultChains()
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
var err error
|
||||
ipset, err = m.addIpToSet(ipsetName, ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRules = append(newRules, ioRule)
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.nftSet == nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||
if !ok {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
|
||||
return err
|
||||
}
|
||||
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ips) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.rules, r.ID())
|
||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||
|
||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.rConn.FlushSet(r.nftSet)
|
||||
m.rConn.DelSet(r.nftSet)
|
||||
m.ipsetStore.deleteIpset(r.nftSet.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainInputRules,
|
||||
Position: 0,
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *AclManager) Flush() error {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
nftRule: r.nftRule,
|
||||
mangleRule: r.mangleRule,
|
||||
nftSet: r.nftSet,
|
||||
ruleID: r.ruleID,
|
||||
ip: ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var expressions []expr.Any
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.protoOffset,
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := m.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
}
|
||||
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{protoData},
|
||||
})
|
||||
}
|
||||
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.srcAddrOffset,
|
||||
Len: m.af.addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipset.Name,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
expressions = append(expressions, applyPort(sPort, true)...)
|
||||
expressions = append(expressions, applyPort(dPort, false)...)
|
||||
|
||||
mainExpressions := slices.Clone(expressions)
|
||||
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
case firewall.ActionDrop:
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(ruleId)
|
||||
|
||||
chain := m.chainInputRules
|
||||
rule := &nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: mainExpressions,
|
||||
UserData: userData,
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var nftRule *nftables.Rule
|
||||
if action == firewall.ActionDrop {
|
||||
nftRule = m.rConn.InsertRule(rule)
|
||||
} else {
|
||||
nftRule = m.rConn.AddRule(rule)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||
}
|
||||
|
||||
ruleStruct := &Rule{
|
||||
nftRule: nftRule,
|
||||
// best effort mangle rule
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
m.rules[ruleId] = ruleStruct
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
|
||||
return ruleStruct, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if m.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: m.workTable,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: m.workTable,
|
||||
Hooknum: hookNum,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
return m.rConn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
|
||||
m.ipsetStore.newIpset(ipset.Name)
|
||||
}
|
||||
|
||||
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
|
||||
m.ipsetStore.AddIpToSet(ipset.Name, ip)
|
||||
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: m.af.setKeyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if m.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(m.workTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) == 0 {
|
||||
continue
|
||||
}
|
||||
split := bytes.Split(rule.UserData, []byte(" "))
|
||||
r, ok := m.rules[string(split[0])]
|
||||
if ok {
|
||||
if mangle {
|
||||
*r.mangleRule = *rule
|
||||
} else {
|
||||
*r.nftRule = *rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":" + string(proto) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipset == nil {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
|
||||
|
||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||
if af.addrLen == 4 {
|
||||
return ip.To4()
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
885
client/firewall/nftables/chains_linux.go
Normal file
885
client/firewall/nftables/chains_linux.go
Normal file
@@ -0,0 +1,885 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: &prio,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingRdr,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePostrouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameMangleForward,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.addPostroutingRules()
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *family) setupDataPlaneMark() error {
|
||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||
return errors.New("no mangle chains found")
|
||||
}
|
||||
|
||||
ctNew := getCtNewExprs()
|
||||
preExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
preExprs = append(preExprs, ctNew...)
|
||||
preExprs = append(preExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
preNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: preExprs,
|
||||
}
|
||||
r.conn.AddRule(preNftRule)
|
||||
|
||||
postExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
postExprs = append(postExprs, ctNew...)
|
||||
postExprs = append(postExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
postNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePostrouting],
|
||||
Exprs: postExprs,
|
||||
}
|
||||
r.conn.AddRule(postNftRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) acceptForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.acceptFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) acceptFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available.
|
||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) getAcceptForwardRules() [][]string {
|
||||
intf := r.wgIface.Name()
|
||||
return [][]string{
|
||||
{"-i", intf, "-j", "ACCEPT"},
|
||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||
// This is used when iptables is not available.
|
||||
func (r *family) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
forwardChain := &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertForwardAcceptRules(forwardChain, intf)
|
||||
|
||||
inputChain := &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertInputAcceptRule(inputChain, intf)
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||
func (r *family) acceptExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
for _, chain := range chains {
|
||||
r.applyExternalChainAccept(chain, intf)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
r.insertForwardIifRule(chain, intf, existing)
|
||||
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
||||
}
|
||||
|
||||
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleIif] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleOif] {
|
||||
return
|
||||
}
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
if existing[userDataAcceptInputRule] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
})
|
||||
}
|
||||
|
||||
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
||||
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
present := map[string]bool{}
|
||||
for _, rule := range rules {
|
||||
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
||||
continue
|
||||
}
|
||||
present[string(rule.UserData)] = true
|
||||
}
|
||||
return present, nil
|
||||
}
|
||||
|
||||
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
||||
switch string(userData) {
|
||||
case userDataAcceptForwardRuleIif,
|
||||
userDataAcceptForwardRuleOif,
|
||||
userDataAcceptInputRule:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptFilterRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != table.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||
// ensuring cleanup works even after a crash or if chains changed.
|
||||
func (r *family) removeExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||
func (r *family) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
log.Debugf("list chains for family %d: %v", family, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, chain := range allChains {
|
||||
if r.isExternalChain(chain) {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chains
|
||||
}
|
||||
|
||||
func (r *family) isExternalChain(chain *nftables.Chain) bool {
|
||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Type != nftables.ChainTypeFilter {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Hooknum == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||
}
|
||||
|
||||
func isIptablesTable(name string) bool {
|
||||
switch name {
|
||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chainInputRules,
|
||||
Position: 0,
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (r *family) Flush() error {
|
||||
if err := r.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if r.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
nfRule := r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (r *family) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := r.createChain(chainNameInputRules)
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
r.chainInputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
r.chainPrerouting = r.chains[chainNameManglePrerouting]
|
||||
|
||||
r.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: r.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) createChain(name string) *nftables.Chain {
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: r.workTable,
|
||||
}
|
||||
|
||||
chain = r.conn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: r.workTable,
|
||||
Hooknum: hookNum,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
return r.conn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if r.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := r.conn.GetRules(r.workTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) == 0 {
|
||||
continue
|
||||
}
|
||||
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if mangle {
|
||||
if pr.mangleRule != nil {
|
||||
*pr.mangleRule = *rule
|
||||
}
|
||||
} else {
|
||||
*pr.nftRule = *rule
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
533
client/firewall/nftables/dnat_linux.go
Normal file
533
client/firewall/nftables/dnat_linux.go
Normal file
@@ -0,0 +1,533 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.addDnatMasq(rule, protoNum, ruleID)
|
||||
|
||||
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||
// TODO: find chains with drop policies and add rules there
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush rules: %w", err)
|
||||
}
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
|
||||
dnatExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
}
|
||||
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||
|
||||
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
|
||||
}
|
||||
}
|
||||
|
||||
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleID + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleID+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
switch {
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
return r.handlePortRange(rule)
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
return r.handleAddressOnly(rule)
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
return r.handleSinglePort(rule)
|
||||
default:
|
||||
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 3, nil
|
||||
}
|
||||
|
||||
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
return exprs, 0, 0, nil
|
||||
}
|
||||
|
||||
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 0, nil
|
||||
}
|
||||
|
||||
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.Counter{},
|
||||
&expr.Target{
|
||||
Name: "DNAT",
|
||||
Rev: 2,
|
||||
Info: &xt.NatRange2{
|
||||
NatRange: xt.NatRange{
|
||||
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||
MinPort: rule.TranslatedPort.Values[0],
|
||||
MaxPort: rule.TranslatedPort.Values[1],
|
||||
},
|
||||
BasePort: rule.DestinationPort.Values[0],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
natTable := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: natTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: natTable,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
},
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleID + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleID+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) {
|
||||
masqExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.dstAddrOffset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
|
||||
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||
masqExprs = append(masqExprs, &expr.Masq{})
|
||||
|
||||
masqRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: masqExprs,
|
||||
UserData: []byte(ruleID + snatSuffix),
|
||||
}
|
||||
r.conn.AddRule(masqRule)
|
||||
r.rules[ruleID+snatSuffix] = masqRule
|
||||
}
|
||||
|
||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleID := rule.ID()
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
||||
if dnatRule.Handle == 0 {
|
||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
|
||||
delete(r.rules, ruleID+dnatSuffix)
|
||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
||||
if masqRule.Handle == 0 {
|
||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
|
||||
delete(r.rules, ruleID+snatSuffix)
|
||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsFlush {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
}
|
||||
|
||||
if merr == nil {
|
||||
delete(r.rules, ruleID+dnatSuffix)
|
||||
delete(r.rules, ruleID+snatSuffix)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *family) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
249
client/firewall/nftables/family_linux.go
Normal file
249
client/firewall/nftables/family_linux.go
Normal file
@@ -0,0 +1,249 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableRaw = "raw"
|
||||
tableSecurity = "security"
|
||||
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
// Peer ACL chain names.
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
flushError = "flush: %w"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
|
||||
dnatSuffix firewall.RuleID = "_dnat"
|
||||
snatSuffix firewall.RuleID = "_snat"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
refreshRulesMapError = "refresh rules map: %w"
|
||||
)
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
)
|
||||
|
||||
type setInput struct {
|
||||
set firewall.Set
|
||||
prefixes []netip.Prefix
|
||||
}
|
||||
|
||||
// family holds the per-address-family nftables state. One instance
|
||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||
// single family; the top-level Manager owns one for v4 and another
|
||||
// for v6. The name predates the peer-ACL absorption; it's effectively
|
||||
// the per-family backend now.
|
||||
type family struct {
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
|
||||
// filters holds peer + route filter rules keyed by content hash.
|
||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||
filters map[firewall.RuleID]*Rule
|
||||
|
||||
// rules holds NAT, DNAT, and external accept rules (auxiliary
|
||||
// plumbing that isn't a filter rule).
|
||||
rules map[firewall.RuleID]*nftables.Rule
|
||||
|
||||
// Peer ACL chain handles.
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
routingFwChainName string
|
||||
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
af addrFamily
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||
r := &family{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
filters: make(map[firewall.RuleID]*Rule),
|
||||
rules: make(map[firewall.RuleID]*nftables.Rule),
|
||||
routingFwChainName: chainNameRoutingFw,
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
r.deleteIpSet,
|
||||
)
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
log.Debugf("ip filter table not found: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *family) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
if err := r.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default acl chains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables filter table rules from the system
|
||||
func (r *family) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func hookName(hook *nftables.ChainHook) string {
|
||||
if hook == nil {
|
||||
return "unknown"
|
||||
}
|
||||
switch *hook {
|
||||
case *nftables.ChainHookForward:
|
||||
return chainNameForward
|
||||
case *nftables.ChainHookInput:
|
||||
return chainNameInput
|
||||
default:
|
||||
return fmt.Sprintf("hook(%d)", *hook)
|
||||
}
|
||||
}
|
||||
|
||||
func familyName(family nftables.TableFamily) string {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4:
|
||||
return "ip"
|
||||
case nftables.TableFamilyIPv6:
|
||||
return "ip6"
|
||||
case nftables.TableFamilyINet:
|
||||
return "inet"
|
||||
default:
|
||||
return fmt.Sprintf("family(%d)", family)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) iptablesProto() iptables.Protocol {
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
return iptables.ProtocolIPv6
|
||||
}
|
||||
return iptables.ProtocolIPv4
|
||||
}
|
||||
|
||||
func (r *family) refreshRulesMap() error {
|
||||
var merr *multierror.Error
|
||||
newRules := make(map[firewall.RuleID]*nftables.Rule)
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||
// preserve existing entries for this chain since we can't verify their state
|
||||
for k, v := range r.rules {
|
||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||
newRules[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
newRules[firewall.RuleID(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
r.rules = newRules
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
441
client/firewall/nftables/filter_linux.go
Normal file
441
client/firewall/nftables/filter_linux.go
Normal file
@@ -0,0 +1,441 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
)
|
||||
|
||||
// AddFilterRule installs one nftables packet-filter rule. With
|
||||
// destination empty the rule goes to the peer ACL input chain plus a
|
||||
// paired prerouting mangle rule for the redirect mark. With
|
||||
// destination set (prefix or named set) it goes to the route ACL
|
||||
// forward chain. Multi-source rules collapse to one nftables rule
|
||||
// backed by the shared refcounted hash:net set.
|
||||
func (r *family) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
isRoute := !destination.IsZero()
|
||||
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
if existing, ok := r.filters[ruleID]; ok {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
if isRoute {
|
||||
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
|
||||
} else {
|
||||
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
|
||||
}
|
||||
if err != nil {
|
||||
r.dropNetworkMatch(srcExprs)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainExprs := slices.Clone(exprs)
|
||||
verdict := expr.VerdictAccept
|
||||
if action == firewall.ActionDrop {
|
||||
verdict = expr.VerdictDrop
|
||||
}
|
||||
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
|
||||
|
||||
chain := r.chainInputRules
|
||||
if isRoute {
|
||||
chain = r.chains[chainNameRoutingFw]
|
||||
}
|
||||
|
||||
userData := []byte(ruleID)
|
||||
nftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: mainExprs,
|
||||
UserData: userData,
|
||||
}
|
||||
if action == firewall.ActionDrop {
|
||||
nftRule = r.conn.InsertRule(nftRule)
|
||||
} else {
|
||||
nftRule = r.conn.AddRule(nftRule)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.dropNetworkMatch(exprs)
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
sources: sources,
|
||||
id: ruleID,
|
||||
}
|
||||
if !isRoute {
|
||||
rule.mangleRule = r.createPreroutingRule(exprs, userData)
|
||||
}
|
||||
r.filters[ruleID] = rule
|
||||
|
||||
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
|
||||
sources, destination, proto, sPort, dPort, action)
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
|
||||
// IP-header protocol byte read via Payload, then source, then ports
|
||||
// (no counter), matching the historical peer shape so per-rule kernel
|
||||
// state is identical to pre-unification.
|
||||
func (r *family) buildPeerFilterExprs(
|
||||
srcExprs []expr.Any,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
) ([]expr.Any, error) {
|
||||
var exprs []expr.Any
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.protoOffset,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs, srcExprs...)
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
|
||||
// source, then destination, then optional proto/ports, then a counter.
|
||||
func (r *family) buildRouteFilterExprs(
|
||||
srcExprs []expr.Any,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
) ([]expr.Any, error) {
|
||||
exprs := append([]expr.Any{}, srcExprs...)
|
||||
|
||||
destExprs, err := r.applyNetwork(destination, nil, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
exprs = append(exprs, destExprs...)
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
r.dropNetworkMatch(destExprs)
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||
)
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *family) hasRule(id firewall.RuleID) bool {
|
||||
_, ok := r.filters[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteFilterRule removes a previously installed filter rule. Source
|
||||
// set references are recovered from the stored rule's expressions via
|
||||
// findSets and dropped from the shared refcounter.
|
||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
pr, ok := r.filters[ruleID]
|
||||
if !ok {
|
||||
log.Debugf("filter rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// A freshly added rule carries no handle until it is read back from
|
||||
// the kernel, and Flush only refreshes the peer chains. Pull live
|
||||
// handles for this rule's chain before deciding it is stale so route
|
||||
// rules (which Flush never refreshes) can actually be deleted.
|
||||
if pr.nftRule.Handle == 0 {
|
||||
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
|
||||
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
|
||||
}
|
||||
if pr.mangleRule != nil {
|
||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||
log.Warnf("refresh mangle handles: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pr.nftRule.Handle == 0 {
|
||||
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
|
||||
r.dropNetworkMatch(pr.nftRule.Exprs)
|
||||
delete(r.filters, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(pr.nftRule); err != nil {
|
||||
log.Errorf("queue rule delete: %v", err)
|
||||
}
|
||||
if pr.mangleRule != nil {
|
||||
if err := r.conn.DelRule(pr.mangleRule); err != nil {
|
||||
log.Errorf("queue mangle rule delete: %v", err)
|
||||
}
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete %s: %w", ruleID, err)
|
||||
}
|
||||
|
||||
r.dropNetworkMatch(pr.nftRule.Exprs)
|
||||
delete(r.filters, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
|
||||
if r.ipsetCounter == nil {
|
||||
return nil
|
||||
}
|
||||
sets := findSets(rule)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// dropNetworkMatch undoes whatever the source/destination match
|
||||
// reserved. Safe to call when the spec is empty or holds only inline
|
||||
// matchers.
|
||||
func (r *family) dropNetworkMatch(exprs []expr.Any) {
|
||||
if r.ipsetCounter == nil {
|
||||
return
|
||||
}
|
||||
for _, e := range exprs {
|
||||
lookup, ok := e.(*expr.Lookup)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
|
||||
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) applyNetwork(
|
||||
network firewall.Network,
|
||||
setPrefixes []netip.Prefix,
|
||||
isSource bool,
|
||||
) ([]expr.Any, error) {
|
||||
if network.IsSet() {
|
||||
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||
if err != nil {
|
||||
side := "destination"
|
||||
if isSource {
|
||||
side = "source"
|
||||
}
|
||||
return nil, fmt.Errorf("%s set: %w", side, err)
|
||||
}
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// prefixMatchExprs is the family-aware match sequence for a CIDR
|
||||
// prefix. /0 returns nil; a host prefix (full bit length for the
|
||||
// family) skips the bitwise step since the mask is all-ones. Shared
|
||||
// between family and aclManager so both treat single prefixes
|
||||
// identically.
|
||||
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
offset := af.dstAddrOffset
|
||||
if isSource {
|
||||
offset = af.srcAddrOffset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: af.addrLen,
|
||||
}
|
||||
cmp := &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: prefix.Masked().Addr().AsSlice(),
|
||||
}
|
||||
|
||||
if ones == af.totalBits {
|
||||
return []expr.Any{payload, cmp}
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, af.totalBits)
|
||||
xor := make([]byte, af.addrLen)
|
||||
return []expr.Any{
|
||||
payload,
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: af.addrLen,
|
||||
Mask: mask,
|
||||
Xor: xor,
|
||||
},
|
||||
cmp,
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
|
||||
// src
|
||||
offset := uint32(2)
|
||||
if isSource {
|
||||
// dst
|
||||
offset = 0
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: offset,
|
||||
Len: 2,
|
||||
})
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
exprs = append(exprs,
|
||||
&expr.Range{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
for i, p := range port.Values {
|
||||
if i > 0 {
|
||||
exprs = append(exprs, &expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
})
|
||||
}
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(p),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func getCtNewExprs() []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||
// single prefix inline, or an ipset for multiple sources.
|
||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||
switch {
|
||||
case len(sources) == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1:
|
||||
return firewall.Network{Prefix: sources[0]}
|
||||
default:
|
||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||
}
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
|
||||
// findSets scans an nftables rule's expressions for expr.Lookup and
|
||||
// returns the named sets in occurrence order. Used at delete time to
|
||||
// drop ipsetCounter references; peer and route ACLs go through it.
|
||||
func findSets(rule *nftables.Rule) []string {
|
||||
var sets []string
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
sets = append(sets, lookup.SetName)
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
196
client/firewall/nftables/ipset_linux.go
Normal file
196
client/firewall/nftables/ipset_linux.go
Normal file
@@ -0,0 +1,196 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||
set: set,
|
||||
prefixes: prefixes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return r.getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||
|
||||
nfset := &nftables.Set{
|
||||
Name: setName,
|
||||
Comment: input.set.Comment(),
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: r.af.setKeyType,
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
initialElements := elements[:min(maxElements, nElements)]
|
||||
|
||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||
|
||||
// The set is committed now. If a later batch fails, destroy it: the
|
||||
// refcounter records nothing on a create-callback error, so it would
|
||||
// otherwise leak, and a partial source set fails-open for deny rules.
|
||||
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
|
||||
if derr := r.deleteIpSet(setName, nfset); derr != nil {
|
||||
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
// addRemainingElements adds element batches beyond the initial one in
|
||||
// maxElements-sized chunks, flushing each. Called after the set has been
|
||||
// created with its first batch.
|
||||
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
|
||||
nElements := len(elements)
|
||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||
subEnd := min(subStart+maxElements, nElements)
|
||||
subElement := elements[subStart:subEnd]
|
||||
nSubPrefixes := len(subElement) / 2
|
||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range prefixes {
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
return elements
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
masked := prefix.Masked()
|
||||
if masked.Addr().Is4() {
|
||||
hostMask := ^uint32(0) >> masked.Bits()
|
||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// IPv6: set host bits to all 1s
|
||||
b := masked.Addr().As16()
|
||||
bits := masked.Bits()
|
||||
for i := bits; i < 128; i++ {
|
||||
b[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
return netip.AddrFrom16(b)
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||
b := addr.As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||
func uint32ToBytes(ip uint32) [4]byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], ip)
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||
r.conn.DelSet(nfset)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
|
||||
// interval set reject the batch, so merge them as createIpSet does.
|
||||
prefixes = firewall.MergeIPRanges(prefixes)
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsetReference map[string]int
|
||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsetReference: make(map[string]int),
|
||||
ipsets: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||
s.ipsetReference[ipsetName] = 0
|
||||
ipList := make(map[string]struct{})
|
||||
s.ipsets[ipsetName] = ipList
|
||||
return ipList
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsetReference, ipsetName)
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ipList, ip.String())
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ipList[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = ipList[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||
s.ipsetReference[ipsetName]++
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||
r, ok := s.ipsetReference[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if r == 0 {
|
||||
return
|
||||
}
|
||||
s.ipsetReference[ipsetName]--
|
||||
}
|
||||
|
||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||
return false
|
||||
}
|
||||
if s.ipsetReference[ipsetName] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -45,18 +44,17 @@ type iFaceMapper interface {
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
// Manager of nftables firewall. Per-family state (peer ACLs, route
|
||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||
// by family and provides the public firewall.Manager surface.
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
router6 *router
|
||||
aclManager6 *AclManager
|
||||
family4 *family
|
||||
// IPv6 counterpart, nil when no v6 overlay.
|
||||
family6 *family
|
||||
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
@@ -75,14 +73,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
m.family4, err = newFamily(workTable, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
return nil, fmt.Errorf("create family: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
@@ -100,26 +93,21 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
|
||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||
|
||||
var err error
|
||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||
m.family6, err = newFamily(workTable6, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
return fmt.Errorf("create v6 family: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
m.family6.ipFwdState = m.family4.ipFwdState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.router6 != nil
|
||||
return m.family6 != nil
|
||||
}
|
||||
|
||||
func (m *Manager) initIPv6() error {
|
||||
@@ -128,12 +116,8 @@ func (m *Manager) initIPv6() error {
|
||||
return fmt.Errorf("create v6 work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
if err := m.family6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 family init: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -162,13 +146,13 @@ func (m *Manager) reconcileExternalChains() error {
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
if m.router != nil {
|
||||
if err := m.router.acceptExternalChainsRules(); err != nil {
|
||||
if m.family4 != nil {
|
||||
if err := m.family4.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
||||
}
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
||||
if err := m.family6.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -187,12 +171,8 @@ func (m *Manager) initFirewall() (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
if err := m.family4.init(workTable); err != nil {
|
||||
return fmt.Errorf("family init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
@@ -220,7 +200,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
MTU: m.family4.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -235,12 +215,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
|
||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||
func (m *Manager) rollbackInit() {
|
||||
if err := m.router.Reset(); err != nil {
|
||||
log.Warnf("rollback router: %v", err)
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
log.Warnf("rollback family: %v", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 router: %v", err)
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 family: %v", err)
|
||||
}
|
||||
}
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
@@ -251,118 +231,77 @@ func (m *Manager) rollbackInit() {
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFilterRule installs a packet-filtering rule.
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// Destination semantics: zero Network → input chain (peer ACL);
|
||||
// set Network → forward chain (route ACL).
|
||||
//
|
||||
// Sources are a single address family; the rule is dispatched to the
|
||||
// matching per-family backend.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
fam := m.family4
|
||||
if isIPv6Rule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
fam = m.family6
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteFilterRule removes a filtering rule. The owning family is found
|
||||
// by id, refreshing from the kernel if the in-memory caches miss so a
|
||||
// stale cache cannot leak the rule. family.DeleteFilterRule is idempotent
|
||||
// when the id is absent.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||
return m.aclManager6.DeletePeerRule(rule)
|
||||
}
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6Rule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
||||
// router; the cached maps are normally authoritative, so the kernel is only
|
||||
// consulted when neither map knows about the rule.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
id := rule.ID()
|
||||
r, err := m.routerForRuleID(id, (*router).hasRule)
|
||||
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return r.DeleteRouteRule(rule)
|
||||
return fam.DeleteFilterRule(rule)
|
||||
}
|
||||
|
||||
// routerForRuleID picks the router holding the rule with the given id, using
|
||||
// familyForRuleID picks the family holding the rule with the given id, using
|
||||
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
||||
// from the kernel once and re-checks before falling back to the v4 router.
|
||||
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
||||
if has(m.router, id) {
|
||||
return m.router, nil
|
||||
// from the kernel once and re-checks before falling back to the v4 family.
|
||||
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
|
||||
if has(m.family4, id) {
|
||||
return m.family4, nil
|
||||
}
|
||||
if m.hasIPv6() && has(m.router6, id) {
|
||||
return m.router6, nil
|
||||
if m.hasIPv6() && has(m.family6, id) {
|
||||
return m.family6, nil
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return m.router, nil
|
||||
return m.family4, nil
|
||||
}
|
||||
if err := m.router.refreshRulesMap(); err != nil {
|
||||
if err := m.family4.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
||||
}
|
||||
if err := m.router6.refreshRulesMap(); err != nil {
|
||||
if err := m.family6.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
||||
}
|
||||
if has(m.router6, id) && !has(m.router, id) {
|
||||
return m.router6, nil
|
||||
if has(m.family6, id) && !has(m.family4, id) {
|
||||
return m.family6, nil
|
||||
}
|
||||
return m.router, nil
|
||||
return m.family4, nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -381,10 +320,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
return m.family6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
if err := m.family4.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -396,7 +335,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
// so the eventual cleanup still works.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -412,18 +351,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
return m.family6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -445,11 +384,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
if err := m.family4.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||
if err := m.family6.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -466,11 +405,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -484,13 +423,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,14 +469,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -551,12 +490,12 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
if err := m.family4.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.Flush(); err != nil {
|
||||
if err := m.family6.Flush(); err != nil {
|
||||
return fmt.Errorf("flush v6 acl: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -577,9 +516,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
return m.family6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
return m.family4.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
@@ -587,7 +526,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
||||
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -608,12 +547,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -630,9 +569,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
@@ -644,9 +583,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
@@ -658,9 +597,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
@@ -672,9 +611,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -903,3 +842,14 @@ func getEstablishedExprs(register uint32) []expr.Any {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
|
||||
// prefix destination the destination family decides; otherwise the
|
||||
// (single-family) sources do, since management duplicates rules per
|
||||
// family.
|
||||
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
@@ -70,13 +72,13 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 2, "expected 2 rules")
|
||||
@@ -147,15 +149,12 @@ func TestNftablesManager(t *testing.T) {
|
||||
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
|
||||
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
@@ -180,47 +179,39 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
// Add accept rule first
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add accept rule")
|
||||
|
||||
// Add deny rule second for the same traffic
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||
|
||||
// Find the accept and deny rules and verify deny comes before accept
|
||||
// Single-source rules emit a direct payload+cmp on the source IP
|
||||
// (no set lookup). Match by source-IP + port + verdict instead of
|
||||
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
|
||||
// that this test predates.
|
||||
wantSrc := ip.AsSlice()
|
||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
hasAcceptHTTPSet := false
|
||||
hasDenyHTTPSet := false
|
||||
hasPort80 := false
|
||||
var hasSrc, hasPort80 bool
|
||||
var action string
|
||||
|
||||
for _, e := range rule.Exprs {
|
||||
// Check for set lookup
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
switch lookup.SetName {
|
||||
case "accept-http":
|
||||
hasAcceptHTTPSet = true
|
||||
case "deny-http":
|
||||
hasDenyHTTPSet = true
|
||||
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
|
||||
if bytes.Equal(cmp.Data, wantSrc) {
|
||||
hasSrc = true
|
||||
}
|
||||
|
||||
}
|
||||
// Check for port 80
|
||||
if cmp, ok := e.(*expr.Cmp); ok {
|
||||
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||
hasPort80 = true
|
||||
}
|
||||
}
|
||||
// Check for verdict
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch verdict.Kind {
|
||||
case expr.VerdictAccept:
|
||||
@@ -231,11 +222,15 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
|
||||
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
|
||||
if !hasSrc || !hasPort80 {
|
||||
continue
|
||||
}
|
||||
switch action {
|
||||
case "ACCEPT":
|
||||
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
|
||||
acceptRuleIndex = i
|
||||
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
|
||||
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
|
||||
case "DROP":
|
||||
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
|
||||
denyRuleIndex = i
|
||||
}
|
||||
}
|
||||
@@ -279,7 +274,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -361,10 +356,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||
@@ -437,10 +432,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("fd00::2")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "add v6 peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
@@ -550,7 +545,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||
}
|
||||
}
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
prefixes,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
@@ -565,7 +560,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
}
|
||||
|
||||
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
@@ -591,9 +586,9 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{},
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
// need fw manager to init both acl mgr and family for all chains to be present
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
|
||||
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
testRouter := &router{af: afIPv4}
|
||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
testRouter := &family{af: afIPv4}
|
||||
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
|
||||
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
// Combine all expressions in the correct order
|
||||
// nolint:gocritic
|
||||
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
testingExpression = append(testingExpression, sourceExp...)
|
||||
testingExpression = append(testingExpression, destExp...)
|
||||
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
||||
found := 0
|
||||
for _, chain := range rtr.chains {
|
||||
if chain.Name == chainNameManglePrerouting {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
// Compare expressions up to the mark setting expressions
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||
found = 1
|
||||
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
|
||||
// First add the NAT rule using the router's method
|
||||
// First add the NAT rule using the family's method
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
|
||||
// Verify the rule was added
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
||||
found := false
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func(r *router) {
|
||||
defer func(r *family) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
}(r)
|
||||
|
||||
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddFilterRule failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
|
||||
require.True(t, ok, "Rule not found in filters map")
|
||||
rule := stored.nftRule
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
for i, expr := range rule.Exprs {
|
||||
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.ID() {
|
||||
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -509,6 +509,42 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
|
||||
// overlapping prefixes before adding them. An interval set rejects
|
||||
// overlapping elements, so without the merge a batch holding a /32 already
|
||||
// covered by a /24, or a duplicate address as DNS resolution can produce,
|
||||
// would fail.
|
||||
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err, "create work table")
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "reset family")
|
||||
}()
|
||||
|
||||
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||
set := firewall.NewPrefixSet(initial)
|
||||
|
||||
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
|
||||
require.NoError(t, err, "create ip set")
|
||||
require.NotNil(t, created)
|
||||
|
||||
overlapping := []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
netip.MustParsePrefix("192.168.1.1/32"),
|
||||
}
|
||||
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
|
||||
}
|
||||
|
||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
@@ -518,11 +554,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to create v6 work table")
|
||||
defer deleteWorkTableIPv6()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -861,13 +897,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Add a real rule to the kernel
|
||||
ruleKey, err := r.AddRouteFiltering(
|
||||
ruleKey, err := r.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
@@ -878,11 +914,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||
})
|
||||
|
||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||
staleKey := "stale-rule-that-does-not-exist"
|
||||
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
@@ -902,6 +938,55 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||
}
|
||||
|
||||
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
|
||||
// rule is actually removed from the kernel on delete. The route chain is
|
||||
// not refreshed by Flush, so the stored rule carries a zero handle;
|
||||
// DeleteFilterRule must pull live handles itself before issuing the
|
||||
// delete or the kernel rule leaks. Regression test for that path.
|
||||
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
ruleKey, err := r.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
firewall.ProtocolTCP,
|
||||
nil,
|
||||
&firewall.Port{Values: []uint16{80}},
|
||||
firewall.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
countKernelRules := func() int {
|
||||
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||
require.NoError(t, err)
|
||||
n := 0
|
||||
for _, rule := range list {
|
||||
if string(rule.UserData) == string(ruleKey.ID()) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
|
||||
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
|
||||
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
|
||||
}
|
||||
|
||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
@@ -911,24 +996,28 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Inject a stale entry with Handle=0
|
||||
staleKey := "stale-route-rule"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
staleKey := id.RuleID("stale-route-rule")
|
||||
staleRule := &Rule{
|
||||
nftRule: &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
},
|
||||
id: staleKey,
|
||||
}
|
||||
r.filters[staleKey] = staleRule
|
||||
|
||||
// DeleteRouteRule should not return an error for stale handles
|
||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||
// DeleteFilterRule should not return an error for stale handles
|
||||
err = r.DeleteFilterRule(staleRule)
|
||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
|
||||
}
|
||||
|
||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
@@ -950,7 +1039,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
|
||||
// First add succeeds
|
||||
err = rtr.AddNatRule(pair)
|
||||
@@ -960,11 +1049,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
})
|
||||
|
||||
// Corrupt the handle to simulate stale state
|
||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
|
||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
|
||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||
rule.Handle = 0
|
||||
}
|
||||
@@ -979,7 +1068,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
|
||||
found := 0
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found++
|
||||
}
|
||||
}
|
||||
@@ -1010,7 +1099,7 @@ func TestCalculateLastIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||
r := &router{af: afIPv6}
|
||||
r := &family{af: afIPv6}
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
|
||||
490
client/firewall/nftables/routing_linux.go
Normal file
490
client/firewall/nftables/routing_linux.go
Normal file
@@ -0,0 +1,490 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||
func (r *family) rollbackRules(pair firewall.RouterPair) {
|
||||
keys := []firewall.RuleID{
|
||||
pair.GenKey(firewall.ForwardingFormat),
|
||||
pair.GenKey(firewall.PreroutingFormat),
|
||||
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
|
||||
}
|
||||
for _, key := range keys {
|
||||
rule, ok := r.rules[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||
}
|
||||
delete(r.rules, key)
|
||||
}
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: op,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs = append(exprs, getCtNewExprs()...)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addPostroutingRules() {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
})
|
||||
|
||||
// Second masquerade rule for traffic going out through WireGuard interface
|
||||
exprs2 := []expr.Any{
|
||||
// Match on the second fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
},
|
||||
|
||||
// Match WireGuard interface
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
func (r *family) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
if r.mtu <= overhead {
|
||||
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
||||
return nil
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 13,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 1,
|
||||
Mask: []byte{0x02},
|
||||
Xor: []byte{0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Exthdr{
|
||||
DestRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGt,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Exthdr{
|
||||
SourceRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameMangleForward],
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
exprs = append(exprs,
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
)
|
||||
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the route manager's legacy management mode
|
||||
func (r *family) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
}
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from nat table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Delete rules that have our UserData suffix
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||
}
|
||||
|
||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||
// counters will be off until the next successful removal or refresh cycle.
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
log.Debugf("prerouting rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,21 +1,26 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
// Rule wraps an installed filter rule (peer or route). Source set
|
||||
// membership is encoded in the rule's expressions; DeleteFilterRule
|
||||
// recovers the set name via findSets so the refcounter can drop the
|
||||
// right reference. mangleRule is set only for peer rules.
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
mangleRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
// sources is the canonical source list this rule was created for.
|
||||
sources []netip.Prefix
|
||||
id manager.RuleID
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
// ID returns the rule id
|
||||
func (r *Rule) ID() manager.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
27
client/firewall/nftables/testhelpers_linux_test.go
Normal file
27
client/firewall/nftables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,27 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
||||
}
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@@ -72,12 +71,14 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
// peerRules is the canonical list-based storage for peer ACL rules.
|
||||
// Match order is significant: drop rules come before accept rules so
|
||||
// callers should consult the slice in order.
|
||||
type peerRules []*PeerRule
|
||||
|
||||
type RouteRules []*RouteRule
|
||||
type routeRules []*RouteRule
|
||||
|
||||
func (r RouteRules) Sort() {
|
||||
func (r routeRules) Sort() {
|
||||
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||
// Deny rules come first
|
||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||
@@ -86,22 +87,44 @@ func (r RouteRules) Sort() {
|
||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.id, b.id)
|
||||
return strings.Compare(string(a.id), string(b.id))
|
||||
})
|
||||
}
|
||||
|
||||
// peerRuleSpec carries the parameters that define a peer filter rule,
|
||||
// threaded together through the build path so the builders take a single
|
||||
// argument instead of a long parameter list.
|
||||
type peerRuleSpec struct {
|
||||
mgmtID []byte
|
||||
sources []netip.Prefix
|
||||
ipLayer gopacket.LayerType
|
||||
matchAny bool
|
||||
proto firewall.Protocol
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
// nativeFirewall is the kernel firewall (nftables/iptables) used for
|
||||
// the split case where peer ACLs run in userspace here but routing
|
||||
// stays in the kernel: when the userspace firewall is forced yet the
|
||||
// router keeps using the kernel, route NAT/ACLs and DNAT are
|
||||
// delegated to it. It is nil when no native backend is available.
|
||||
nativeFirewall firewall.Manager
|
||||
mutex sync.RWMutex
|
||||
|
||||
mutex sync.RWMutex
|
||||
incomingDenyRules peerRules
|
||||
incomingAcceptRules peerRules
|
||||
incomingDenyIndex peerRuleIndex
|
||||
incomingAcceptIndex peerRuleIndex
|
||||
peerRulesMap map[nbid.RuleID]*PeerRule
|
||||
|
||||
routeRules routeRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
|
||||
// indicates whether server routes are disabled
|
||||
disableServerRoutes bool
|
||||
@@ -183,24 +206,6 @@ func (d *decoder) decodePacket(data []byte) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func parseCreateEnv() (bool, bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||
var err error
|
||||
@@ -231,7 +236,7 @@ func parseCreateEnv() (bool, bool, bool) {
|
||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
func Create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
@@ -254,11 +259,8 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
return d
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingDenyRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
wgIface: iface,
|
||||
nativeFirewall: nativeFirewall,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
stateful: !disableConntrack,
|
||||
@@ -266,6 +268,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
@@ -320,7 +323,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
v4Rule, err := m.addRouteFiltering(
|
||||
v4Rule, err := m.addRouteRule(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: wgPrefix},
|
||||
@@ -336,7 +339,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule,
|
||||
|
||||
if v6Net.IsValid() {
|
||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||
v6Rule, err := m.addRouteFiltering(
|
||||
v6Rule, err := m.addRouteRule(
|
||||
nil,
|
||||
sources,
|
||||
firewall.Network{Prefix: v6Net},
|
||||
@@ -488,64 +491,136 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
// addPeerRule installs an input-chain rule that matches packets
|
||||
// by source only. Called from AddFilterRule when the caller doesn't
|
||||
// specify a destination. Mixed-family inputs are split: each family
|
||||
// gets its own rule with a family-correct ipLayer so packet decoding
|
||||
// matches what the matcher expects.
|
||||
func (m *Manager) addPeerRule(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
sources []netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
mgmtId: id,
|
||||
ip: i,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
}
|
||||
if i.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
r.matchByIP = false
|
||||
}
|
||||
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
if r.drop {
|
||||
targetMap = m.incomingDenyRules
|
||||
} else {
|
||||
targetMap = m.incomingRules
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if sourcesMatchAny(sources) {
|
||||
spec := peerRuleSpec{
|
||||
mgmtID: id,
|
||||
sources: sources,
|
||||
ipLayer: layerTypeAll,
|
||||
matchAny: true,
|
||||
proto: proto,
|
||||
sPort: sPort,
|
||||
dPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
return m.addOnePeerRule(spec), nil
|
||||
}
|
||||
|
||||
if _, ok := targetMap[r.ip]; !ok {
|
||||
targetMap[r.ip] = make(RuleSet)
|
||||
// Sources are a single family; normalize v4-mapped prefixes to plain
|
||||
// v4 and pick the matching IP layer.
|
||||
normalized := make([]netip.Prefix, len(sources))
|
||||
ipLayer := layers.LayerTypeIPv4
|
||||
for i, p := range sources {
|
||||
normalized[i] = firewall.UnmapPrefix(p)
|
||||
if normalized[i].Addr().Is6() {
|
||||
ipLayer = layers.LayerTypeIPv6
|
||||
}
|
||||
}
|
||||
targetMap[r.ip][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
spec := peerRuleSpec{
|
||||
mgmtID: id,
|
||||
sources: normalized,
|
||||
ipLayer: ipLayer,
|
||||
matchAny: false,
|
||||
proto: proto,
|
||||
sPort: sPort,
|
||||
dPort: dPort,
|
||||
action: action,
|
||||
}
|
||||
return m.addOnePeerRule(spec), nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// addOnePeerRule builds and registers a single-family peer rule, or
|
||||
// returns the existing rule when one with the same content key is
|
||||
// already installed. The caller must hold m.mutex. The content key is
|
||||
// the shared GenerateRuleID with an empty destination, so peer
|
||||
// rules dedup the same way route rules and the kernel backends do.
|
||||
//
|
||||
// There is no refcount: a content key is installed once and deleted on
|
||||
// the first DeleteFilterRule for that key. The caller must therefore
|
||||
// key its own tracking by the returned rule id so add and delete stay
|
||||
// balanced per content key; the acl manager does this via
|
||||
// peerRulesPairs. The content key is order-independent, so callers
|
||||
// passing the same sources in any order dedup to one rule.
|
||||
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
|
||||
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
|
||||
if existing, ok := m.peerRulesMap[ruleID]; ok {
|
||||
return existing
|
||||
}
|
||||
|
||||
rule := m.buildPeerRule(ruleID, spec)
|
||||
m.registerPeerRule(rule)
|
||||
return rule
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
|
||||
r := &PeerRule{
|
||||
id: ruleID,
|
||||
mgmtId: spec.mgmtID,
|
||||
sources: spec.sources,
|
||||
matchAny: spec.matchAny,
|
||||
action: spec.action,
|
||||
srcPort: spec.sPort,
|
||||
dstPort: spec.dPort,
|
||||
}
|
||||
if !spec.matchAny {
|
||||
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
|
||||
for _, p := range spec.sources {
|
||||
if p.Bits() == p.Addr().BitLen() {
|
||||
r.sourceAddrs[p.Addr()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
|
||||
return r
|
||||
}
|
||||
|
||||
// registerPeerRule records a freshly built peer rule in the matching
|
||||
// slice, index, and dedup map. The caller must hold m.mutex.
|
||||
func (m *Manager) registerPeerRule(r *PeerRule) {
|
||||
if r.action == firewall.ActionDrop {
|
||||
m.incomingDenyRules = append(m.incomingDenyRules, r)
|
||||
m.incomingDenyIndex.add(r)
|
||||
} else {
|
||||
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
|
||||
m.incomingAcceptIndex.add(r)
|
||||
}
|
||||
m.peerRulesMap[r.id] = r
|
||||
}
|
||||
|
||||
// sourcesMatchAny reports whether the source list matches every source,
|
||||
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
|
||||
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
|
||||
// the deliberate /0 case.
|
||||
func sourcesMatchAny(sources []netip.Prefix) bool {
|
||||
for _, p := range sources {
|
||||
if p.Bits() == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddFilterRule is the unified entry point for both peer (input chain)
|
||||
// and route (forward chain) filtering rules. The destination
|
||||
// distinguishes the two semantics: a zero Network installs an
|
||||
// input-side rule that matches by source only; a set Network installs
|
||||
// a forward-side rule that also matches the destination.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
@@ -553,13 +628,37 @@ func (m *Manager) AddRouteFiltering(
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
if destination.IsZero() {
|
||||
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
|
||||
// is used to route to the correct internal path.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
if r, ok := rule.(*PeerRule); ok {
|
||||
return m.deletePeerRuleLocked(r)
|
||||
}
|
||||
|
||||
// Either our *RouteRule or, under native delegation, a native
|
||||
// route-rule object that implements firewall.Rule but isn't one of
|
||||
// our concrete types. The route path forwards the latter to the
|
||||
// native firewall.
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) addRouteFiltering(
|
||||
func (m *Manager) addRouteRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
@@ -568,18 +667,17 @@ func (m *Manager) addRouteFiltering(
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
return m.nativeFirewall.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: string(ruleKey),
|
||||
id: ruleID,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
@@ -594,72 +692,57 @@ func (m *Manager) addRouteFiltering(
|
||||
|
||||
m.routeRules = append(m.routeRules, &rule)
|
||||
m.routeRules.Sort()
|
||||
m.routeRulesMap[ruleKey] = &rule
|
||||
m.routeRulesMap[ruleID] = &rule
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
return m.nativeFirewall.DeleteFilterRule(rule)
|
||||
}
|
||||
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
ruleID := rule.ID()
|
||||
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
|
||||
if !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
||||
}
|
||||
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == string(ruleKey)
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
}
|
||||
|
||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||
delete(m.routeRulesMap, ruleKey)
|
||||
m.routeRules = trimmed
|
||||
delete(m.routeRulesMap, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
// deletePeerRuleLocked removes a peer rule from the matching slice,
|
||||
// index, and dedup map. The caller must hold m.mutex.
|
||||
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
|
||||
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
|
||||
if r.action == firewall.ActionDrop {
|
||||
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
|
||||
}
|
||||
|
||||
r, ok := rule.(*PeerRule)
|
||||
trimmed, stored, ok := removeRuleByID(*target, r.id)
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
|
||||
var sourceMap map[netip.Addr]RuleSet
|
||||
if r.drop {
|
||||
sourceMap = m.incomingDenyRules
|
||||
} else {
|
||||
sourceMap = m.incomingRules
|
||||
}
|
||||
|
||||
if ruleset, ok := sourceMap[r.ip]; ok {
|
||||
if _, exists := ruleset[r.id]; !exists {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(ruleset, r.id)
|
||||
if len(ruleset) == 0 {
|
||||
delete(sourceMap, r.ip)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
|
||||
*target = trimmed
|
||||
index.remove(stored)
|
||||
delete(m.peerRulesMap, r.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeRuleByID removes the first rule whose id matches ruleID from
|
||||
// rules, preserving order. It returns the trimmed slice, the removed
|
||||
// rule, and whether a match was found.
|
||||
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
|
||||
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
|
||||
var removed T
|
||||
if idx < 0 {
|
||||
return rules, removed, false
|
||||
}
|
||||
removed = rules[idx]
|
||||
return slices.Delete(rules, idx, idx+1), removed, true
|
||||
}
|
||||
|
||||
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if m.nativeFirewall == nil {
|
||||
@@ -674,9 +757,11 @@ func (m *Manager) Flush() error { return nil }
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
clear(m.outgoingRules)
|
||||
clear(m.incomingDenyRules)
|
||||
clear(m.incomingRules)
|
||||
m.incomingDenyRules = m.incomingDenyRules[:0]
|
||||
m.incomingAcceptRules = m.incomingAcceptRules[:0]
|
||||
m.incomingDenyIndex.reset()
|
||||
m.incomingAcceptIndex.reset()
|
||||
clear(m.peerRulesMap)
|
||||
clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
@@ -820,11 +905,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
case layers.LayerTypeIPv4:
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
return src.Unmap(), dst.Unmap()
|
||||
case layers.LayerTypeIPv6:
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
return src.Unmap(), dst.Unmap()
|
||||
default:
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
}
|
||||
@@ -1404,20 +1489,12 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
@@ -1438,39 +1515,6 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
|
||||
@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -114,15 +114,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFilterRule(
|
||||
nil,
|
||||
ip,
|
||||
pfx(ip), fw.Network{},
|
||||
fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -133,15 +131,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
fw.ActionDrop)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -170,7 +166,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -210,7 +206,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -253,7 +249,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -411,7 +407,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -538,7 +534,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -546,7 +542,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -621,7 +617,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -629,7 +625,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -732,14 +728,14 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -812,13 +808,13 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -931,7 +927,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
|
||||
for _, r := range rules {
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@@ -1016,7 +1012,7 @@ func BenchmarkMSSClamping(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -1081,7 +1077,7 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -1136,7 +1132,7 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
@@ -496,40 +496,32 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
// add general accept rule for the same IP to test drop rule precedence
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
rules, err := manager.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
require.NotNil(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
rules, err := manager.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||
tc.ruleProto,
|
||||
tc.ruleSrcPort,
|
||||
tc.ruleDstPort,
|
||||
tc.ruleAction,
|
||||
"",
|
||||
)
|
||||
tc.ruleAction)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
require.NotNil(t, rules)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
@@ -557,7 +549,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
@@ -672,22 +664,18 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
require.NotNil(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
}
|
||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
||||
})
|
||||
|
||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||
@@ -800,7 +788,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
@@ -1405,7 +1393,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.rule.action == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
@@ -1415,13 +1403,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
})
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
tc.rule.dest,
|
||||
@@ -1431,10 +1419,10 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
tc.rule.action,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
})
|
||||
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
@@ -1602,9 +1590,9 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var rules []fw.Rule
|
||||
var addedRules []fw.Rule
|
||||
for _, r := range tc.rules {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
r.sources,
|
||||
r.dest,
|
||||
@@ -1615,12 +1603,12 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
rules = append(rules, rule)
|
||||
addedRules = append(addedRules, rule)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
for _, rule := range addedRules {
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1646,7 +1634,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -1655,7 +1643,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
// Add rule that uses the set (initially empty)
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -1689,7 +1677,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||
|
||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||
_, err := manager.AddRouteFiltering(
|
||||
_, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: v6Dst},
|
||||
@@ -1700,7 +1688,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
err = manager.DeleteFilterRule(rule1)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -156,6 +156,59 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
"Traffic should still pass after rule update - no gap should occur")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
|
||||
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
|
||||
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
|
||||
// denied. The v4-only soft-skip path is covered by
|
||||
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
|
||||
func TestBlockInvalidRoutedDualStack(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
dev := mocks.NewMockDevice(ctrl)
|
||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||
|
||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
|
||||
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: wgNet.Addr(),
|
||||
Network: wgNet,
|
||||
IPv6: wgNet6.Addr(),
|
||||
IPv6Net: wgNet6,
|
||||
}
|
||||
},
|
||||
GetDeviceFunc: func() *device.FilteredDevice {
|
||||
return &device.FilteredDevice{Device: dev}
|
||||
},
|
||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||
return &wgdevice.Device{}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
|
||||
rules, err := manager.blockInvalidRouted(ifaceMock)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
|
||||
|
||||
manager.mutex.RLock()
|
||||
ruleCount := len(manager.routeRules)
|
||||
manager.mutex.RUnlock()
|
||||
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
|
||||
|
||||
// v6 routed traffic to the WG prefix must be denied.
|
||||
srcIP := netip.MustParseAddr("2001:db8::1")
|
||||
dstIP := netip.MustParseAddr("fd00:1234::50")
|
||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
|
||||
}
|
||||
|
||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||
// returns the same rule without duplicating.
|
||||
@@ -182,7 +235,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -245,7 +298,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -274,7 +327,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -304,7 +357,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -315,7 +368,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -329,7 +382,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
err = manager.DeleteFilterRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
@@ -364,7 +417,7 @@ func setupTestManager(t *testing.T) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.EnableRouting())
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -89,7 +89,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
func TestManagerAddFilterRule(t *testing.T) {
|
||||
isSetFilterCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error {
|
||||
@@ -98,7 +98,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -109,7 +109,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -131,7 +131,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -142,63 +142,33 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Check rules exist in appropriate maps
|
||||
for _, r := range rule2 {
|
||||
peerRule, ok := r.(*PeerRule)
|
||||
if !ok {
|
||||
t.Errorf("rule should be a PeerRule")
|
||||
continue
|
||||
}
|
||||
// Check if rule exists in deny or allow maps based on action
|
||||
var found bool
|
||||
if peerRule.drop {
|
||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||
} else {
|
||||
_, found = m.incomingRules[ip][r.ID()]
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("rule2 is not in the expected rules map")
|
||||
peerRule, ok := rule2.(*PeerRule)
|
||||
require.True(t, ok, "rule should be a PeerRule")
|
||||
|
||||
inMap := func() bool {
|
||||
if peerRule.action == fw.ActionDrop {
|
||||
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
|
||||
}
|
||||
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
err = m.DeletePeerRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
require.True(t, inMap(), "rule2 should be in the expected rules list")
|
||||
|
||||
// Check rules are removed from appropriate maps
|
||||
for _, r := range rule2 {
|
||||
peerRule, ok := r.(*PeerRule)
|
||||
if !ok {
|
||||
t.Errorf("rule should be a PeerRule")
|
||||
continue
|
||||
}
|
||||
// Check if rule is removed from deny or allow maps based on action
|
||||
var found bool
|
||||
if peerRule.drop {
|
||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||
} else {
|
||||
_, found = m.incomingRules[ip][r.ID()]
|
||||
}
|
||||
if found {
|
||||
t.Errorf("rule2 should be removed from the rules map")
|
||||
}
|
||||
}
|
||||
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
|
||||
|
||||
require.False(t, inMap(), "rule2 should be removed from the rules list")
|
||||
}
|
||||
|
||||
func TestSetUDPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
@@ -222,7 +192,7 @@ func TestSetUDPPacketHook(t *testing.T) {
|
||||
func TestSetTCPPacketHook(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||
|
||||
@@ -250,7 +220,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
@@ -260,36 +230,34 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
err = m.DeleteFilterRule(rule1)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
err = m.DeleteFilterRule(rule2)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
require.False(t, exists, "Deny rules should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
@@ -299,7 +267,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
@@ -311,27 +279,21 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
}
|
||||
require.NoError(t, m.DeleteFilterRule(rules))
|
||||
require.NoError(t, m.DeleteFilterRule(allowRules))
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
@@ -345,7 +307,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, m.Close(nil))
|
||||
@@ -354,41 +316,39 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
err = m.DeleteFilterRule(allowRule)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
err = m.DeleteFilterRule(denyRule)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
@@ -400,7 +360,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -411,7 +371,7 @@ func TestManagerReset(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -423,7 +383,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||
t.Errorf("rules are not empty")
|
||||
}
|
||||
}
|
||||
@@ -439,7 +399,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -449,7 +409,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -502,7 +462,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(iface, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -521,7 +481,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close()
|
||||
@@ -606,7 +566,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -621,7 +581,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -633,7 +593,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
}, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
@@ -845,7 +805,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -858,7 +818,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -931,7 +891,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -939,7 +899,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -1051,7 +1011,7 @@ func TestMSSClamping(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, 1280)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -1243,7 +1203,7 @@ func TestShouldForward(t *testing.T) {
|
||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -1358,7 +1318,7 @@ func TestShouldForward(t *testing.T) {
|
||||
|
||||
// Re-create manager to pick up the new address with IPv6
|
||||
require.NoError(t, manager.Close(nil))
|
||||
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
manager, err = Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
v6Cases := []struct {
|
||||
|
||||
@@ -66,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -126,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -198,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -310,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -483,7 +483,7 @@ func BenchmarkPortDNAT(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
func TestPortDNATBasic(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -51,7 +51,7 @@ func TestPortDNATBasic(t *testing.T) {
|
||||
func TestPortDNATMultipleRules(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -106,7 +106,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -154,7 +154,7 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -204,7 +204,7 @@ func TestInboundPortDNAT(t *testing.T) {
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
327
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
327
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
@@ -0,0 +1,327 @@
|
||||
//go:build uspbench
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
|
||||
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
|
||||
// rules, each with K source peers in its set.
|
||||
//
|
||||
// With the reverse-source index, miss cost is independent of M and
|
||||
// hit cost grows only with the number of rules touching a single
|
||||
// srcIP, not with total rule count.
|
||||
func BenchmarkPeerACLMatch(b *testing.B) {
|
||||
shapes := []struct{ M, K int }{
|
||||
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
|
||||
}
|
||||
families := []struct {
|
||||
name string
|
||||
v6 bool
|
||||
}{{"v4", false}, {"v6", true}}
|
||||
|
||||
for _, fam := range families {
|
||||
for _, s := range shapes {
|
||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
|
||||
runPeerACLBench(b, s.M, s.K, true, fam.v6)
|
||||
})
|
||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
|
||||
runPeerACLBench(b, s.M, s.K, false, fam.v6)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
|
||||
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
|
||||
|
||||
// Miss packets are dropped, so they always traverse the full peer
|
||||
// ACL matcher (every bucket) without short-circuiting and without
|
||||
// touching conntrack. Disable conntrack for the miss case so it
|
||||
// measures the matcher, not established-state lookups. The hit case
|
||||
// keeps conntrack on: an accepted packet reaches trackInbound, which
|
||||
// needs the trackers conntrack creates.
|
||||
if !hit {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
}
|
||||
|
||||
bits := 32
|
||||
genPkt := generatePacket
|
||||
addrs := uniqueAddrs
|
||||
if v6 {
|
||||
bits = 128
|
||||
genPkt = generatePacket6
|
||||
addrs = uniqueAddrs6
|
||||
}
|
||||
|
||||
// dstIP must be a local IP so filterInbound takes the local-traffic
|
||||
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
|
||||
// address the manager doesn't own would be treated as routed and
|
||||
// short-circuit before the peer matcher.
|
||||
dstIP := addrs(1, 2)[0]
|
||||
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
|
||||
if v6 {
|
||||
// The local-IP manager needs a valid v4 address too; expose the v6
|
||||
// dst as the interface's IPv6 so IsLocalIP recognizes it.
|
||||
mockAddr = wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: dstIP,
|
||||
IPv6Net: netip.PrefixFrom(dstIP, bits),
|
||||
}
|
||||
}
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address { return mockAddr },
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
|
||||
|
||||
// Generate M policies × K source peers, all distinct.
|
||||
all := addrs(m*k, 1)
|
||||
for i := 0; i < m; i++ {
|
||||
sources := make([]netip.Prefix, k)
|
||||
for j, a := range all[i*k : (i+1)*k] {
|
||||
sources[j] = netip.PrefixFrom(a, bits)
|
||||
}
|
||||
_, err := manager.AddFilterRule(
|
||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Hit: cycle through real sources, picking the matching policy's port.
|
||||
// Miss: a source from a disjoint range, port 80 (matches no policy).
|
||||
var pktFn func(i int) []byte
|
||||
if hit {
|
||||
pktFn = func(i int) []byte {
|
||||
policy := i % m
|
||||
src := all[policy*k+(i%k)]
|
||||
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
|
||||
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
|
||||
}
|
||||
} else {
|
||||
miss := addrs(4096, 99)
|
||||
pktFn = func(i int) []byte {
|
||||
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
|
||||
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-build a pool to avoid allocations dominating the measurement.
|
||||
pool := make([][]byte, 1024)
|
||||
for i := range pool {
|
||||
pool[i] = pktFn(i)
|
||||
}
|
||||
|
||||
// Confirm the matcher is actually exercised: a hit packet must be
|
||||
// allowed and a miss packet dropped. Without this the benchmark
|
||||
// could silently time the routed early-return instead.
|
||||
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
|
||||
"benchmark must reach the peer ACL matcher")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(pool[i%len(pool)], 0)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
|
||||
// the source-keyed index across realistic deployment shapes. Two
|
||||
// dimensions matter: (M, K), the number of policies × peers-per-policy,
|
||||
// and overlap, the fraction of peers shared between policies.
|
||||
//
|
||||
// The output uses ReportMetric("bytes/rule") so the cost can be
|
||||
// compared across shapes directly. Total bytes = bytes/rule * M.
|
||||
func BenchmarkPeerACLIndexMemory(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
M, K int
|
||||
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
|
||||
}{
|
||||
{"M=10/K=100/disjoint", 10, 100, 0},
|
||||
{"M=100/K=100/disjoint", 100, 100, 0},
|
||||
{"M=100/K=1000/disjoint", 100, 1000, 0},
|
||||
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
|
||||
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
|
||||
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
b.Run(c.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mgr, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, nil, false, flowLogger, iface.DefaultMTU)
|
||||
|
||||
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
|
||||
|
||||
runtime.GC()
|
||||
var ms runtime.MemStats
|
||||
runtime.ReadMemStats(&ms)
|
||||
before := ms.HeapAlloc
|
||||
|
||||
// Drop the manager's external roots so we can isolate
|
||||
// the index cost. We hold the manager itself live; the
|
||||
// index is what we measure on the second pass.
|
||||
mgr.incomingAcceptIndex.reset()
|
||||
mgr.incomingDenyIndex.reset()
|
||||
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
|
||||
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&ms)
|
||||
after := ms.HeapAlloc
|
||||
|
||||
delta := int64(before) - int64(after)
|
||||
if delta < 0 {
|
||||
delta = 0
|
||||
}
|
||||
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
|
||||
b.ReportMetric(float64(delta), "bytes/total")
|
||||
|
||||
require.NoError(b, mgr.Close(nil))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
|
||||
b.Helper()
|
||||
pool := uniqueAddrs(k+m*k, 1) // big enough universe
|
||||
sharedLen := int(float64(k) * overlapFrac)
|
||||
if sharedLen > k {
|
||||
sharedLen = k
|
||||
}
|
||||
shared := pool[:sharedLen]
|
||||
uniquePool := pool[sharedLen:]
|
||||
|
||||
for i := 0; i < m; i++ {
|
||||
sources := make([]netip.Prefix, 0, k)
|
||||
for _, a := range shared {
|
||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
|
||||
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
|
||||
for _, a := range unique {
|
||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
_, err := mgr.AddFilterRule(
|
||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
}
|
||||
|
||||
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
|
||||
// policy sources / dst; seed 99 puts misses in 10/8.
|
||||
func uniqueAddrs(n int, seed int64) []netip.Addr {
|
||||
out := make([]netip.Addr, 0, n)
|
||||
seen := make(map[netip.Addr]struct{}, n)
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
miss := seed == 99
|
||||
for len(out) < n {
|
||||
var b [4]byte
|
||||
if miss {
|
||||
b[0] = 10
|
||||
b[1] = byte(r.Intn(256))
|
||||
} else {
|
||||
b[0] = 100
|
||||
b[1] = byte(64 + r.Intn(63))
|
||||
}
|
||||
b[2] = byte(r.Intn(256))
|
||||
b[3] = byte(1 + r.Intn(254))
|
||||
a := netip.AddrFrom4(b)
|
||||
if _, ok := seen[a]; ok {
|
||||
continue
|
||||
}
|
||||
seen[a] = struct{}{}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
|
||||
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
|
||||
// disjoint from any source.
|
||||
func uniqueAddrs6(n int, seed int64) []netip.Addr {
|
||||
out := make([]netip.Addr, 0, n)
|
||||
seen := make(map[netip.Addr]struct{}, n)
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
miss := seed == 99
|
||||
for len(out) < n {
|
||||
var b [16]byte
|
||||
if miss {
|
||||
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
|
||||
} else {
|
||||
b[0] = 0xfd
|
||||
}
|
||||
for x := 8; x < 16; x++ {
|
||||
b[x] = byte(r.Intn(256))
|
||||
}
|
||||
a := netip.AddrFrom16(b)
|
||||
if _, ok := seen[a]; ok {
|
||||
continue
|
||||
}
|
||||
seen[a] = struct{}{}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
|
||||
// generatePacket for the v4 case.
|
||||
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: protocol,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
|
||||
transportLayer = udp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
149
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
149
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
func newTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
return m
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
|
||||
// the same peer rule twice does not create two backing rules. The acl
|
||||
// manager keys its own cache, but the firewall backend must be
|
||||
// idempotent on its own so a double-apply cannot leak rules, matching
|
||||
// the route path and the kernel backends.
|
||||
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "first add")
|
||||
|
||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "second add")
|
||||
|
||||
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
|
||||
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
|
||||
}
|
||||
|
||||
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
|
||||
// backend's no-refcount contract: a content key installed twice is
|
||||
// still one rule, and the first DeleteFilterRule removes it. The
|
||||
// backend does not refcount, so balance is the caller's job (it keys
|
||||
// its tracking by the returned id and deletes once per key). If this
|
||||
// ever silently grew a refcount, the acl manager's delete accounting
|
||||
// would diverge from the kernel.
|
||||
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "first add")
|
||||
|
||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "second add")
|
||||
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
|
||||
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
|
||||
|
||||
require.NoError(t, m.DeleteFilterRule(first), "delete once")
|
||||
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
|
||||
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
|
||||
// content hash, not a random UUID: identical inputs produce the same id
|
||||
// across independent managers. A random id breaks caller-side dedup.
|
||||
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
|
||||
ip := net.ParseIP("10.0.0.5")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
|
||||
m1 := newTestManager(t)
|
||||
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
m2 := newTestManager(t)
|
||||
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
|
||||
// differing only by port are kept separate.
|
||||
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
action := fw.ActionAccept
|
||||
|
||||
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
|
||||
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
|
||||
// matching on source port and one matching on destination port for the
|
||||
// same selector do not collide: the port lands in a different slot, so
|
||||
// the content key must differ.
|
||||
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
|
||||
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
|
||||
}
|
||||
|
||||
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
|
||||
// list is rejected rather than treated as "match any". "Match any" must
|
||||
// be an explicit /0, so a zeroed list can never silently widen a rule to
|
||||
// every source.
|
||||
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
|
||||
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
|
||||
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
|
||||
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
|
||||
}
|
||||
105
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
105
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func newV6TestManager(t *testing.T, localV6 string) *Manager {
|
||||
t.Helper()
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
IPv6: netip.MustParseAddr(localV6),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||
return m
|
||||
}
|
||||
|
||||
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
ip6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
SrcIP: net.ParseIP(src),
|
||||
DstIP: net.ParseIP(dst),
|
||||
}
|
||||
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
|
||||
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
|
||||
// rules: a matching v6 source is accepted, a non-matching one is
|
||||
// denied by the default. This is the end-to-end proof that the index
|
||||
// is not v4-only.
|
||||
func TestPeerACL_IPv6HostRule(t *testing.T) {
|
||||
m := newV6TestManager(t, "fd00::100")
|
||||
|
||||
src := net.ParseIP("fd00::1")
|
||||
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "add v6 accept rule")
|
||||
|
||||
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
|
||||
"v6 packet from the allowed /128 source must be accepted")
|
||||
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
|
||||
"v6 packet from an unlisted source must be denied by default")
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
|
||||
// right index bucket: a /128 in bySource keyed by its address, and
|
||||
// coarser prefixes (including ::/0) in the nonHost slice.
|
||||
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
|
||||
m := newV6TestManager(t, "fd00::100")
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
|
||||
host := netip.MustParseAddr("fd00::1")
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
|
||||
|
||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
|
||||
|
||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
|
||||
// source prefix is normalized to v4 so a plain v4 packet matches it.
|
||||
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
v4 := netip.MustParseAddr("192.168.1.1")
|
||||
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
|
||||
}
|
||||
139
client/firewall/uspfilter/peer_index.go
Normal file
139
client/firewall/uspfilter/peer_index.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// peerRuleIndex is the source-side dispatcher consulted on the packet
|
||||
// hot path. It splits rules into two buckets by the shape of their
|
||||
// source list:
|
||||
//
|
||||
// - bySource: every source is a host prefix (/32 for v4, /128 for
|
||||
// v6). Keyed by the concrete source address, so a hit guarantees
|
||||
// the source filter passes and the matcher goes straight to
|
||||
// proto/port checks. This is the common case for peer ACLs.
|
||||
// - nonHost: any source list with a prefix coarser than a host,
|
||||
// including a /0 "match any". Walked linearly with a per-rule
|
||||
// Contains() check. Expected small or empty for typical peer ACLs.
|
||||
//
|
||||
// Maintained incrementally by add/remove, never rebuilt.
|
||||
type peerRuleIndex struct {
|
||||
bySource map[netip.Addr][]*PeerRule
|
||||
nonHost []*PeerRule
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) add(r *PeerRule) {
|
||||
if hasNonHostSource(r) {
|
||||
i.nonHost = append(i.nonHost, r)
|
||||
return
|
||||
}
|
||||
if i.bySource == nil {
|
||||
i.bySource = make(map[netip.Addr][]*PeerRule)
|
||||
}
|
||||
for a := range r.sourceAddrs {
|
||||
i.bySource[a] = append(i.bySource[a], r)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) remove(r *PeerRule) {
|
||||
if hasNonHostSource(r) {
|
||||
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
|
||||
return
|
||||
}
|
||||
if i.bySource == nil {
|
||||
return
|
||||
}
|
||||
for a := range r.sourceAddrs {
|
||||
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
|
||||
if len(entries) == 0 {
|
||||
delete(i.bySource, a)
|
||||
} else {
|
||||
i.bySource[a] = entries
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) reset() {
|
||||
i.bySource = nil
|
||||
i.nonHost = i.nonHost[:0]
|
||||
}
|
||||
|
||||
// match returns the first rule matching src and the decoded packet.
|
||||
// Host rules are found by direct map lookup; nonHost rules need a
|
||||
// per-rule source Contains() check, except match-any (/0) rules which
|
||||
// apply to every source regardless of family (a v4 /0 also matches v6).
|
||||
// Within either bucket the matcher runs the proto/port filter.
|
||||
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
for _, rule := range i.bySource[src] {
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
for _, rule := range i.nonHost {
|
||||
if !rule.matchAny && !prefixesContain(rule.sources, src) {
|
||||
continue
|
||||
}
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
func eqRule(target *PeerRule) func(*PeerRule) bool {
|
||||
return func(p *PeerRule) bool { return p == target }
|
||||
}
|
||||
|
||||
// hasNonHostSource reports whether the rule has any source prefix
|
||||
// that is not a single host address. Called only at add/remove time,
|
||||
// not on the packet path.
|
||||
func hasNonHostSource(r *PeerRule) bool {
|
||||
for _, p := range r.sources {
|
||||
if p.Bits() != p.Addr().BitLen() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// matchProto applies the proto/port half of a rule against the
|
||||
// decoded packet. Source matching is the caller's responsibility.
|
||||
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
|
||||
drop := rule.action == firewall.ActionDrop
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
return nil, false, false
|
||||
}
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
|
||||
for _, p := range sources {
|
||||
if p.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -10,24 +10,49 @@ import (
|
||||
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
id firewall.RuleID
|
||||
mgmtId []byte
|
||||
// sources is the canonical list of source prefixes this rule
|
||||
// matches against.
|
||||
sources []netip.Prefix
|
||||
// sourceAddrs is a fast-path membership set for host-prefix
|
||||
// sources (/32 v4, /128 v6). Populated alongside sources;
|
||||
// consulted before falling back to prefix scan.
|
||||
sourceAddrs map[netip.Addr]struct{}
|
||||
// matchAny is true when sources covers everything (0.0.0.0/0,
|
||||
// ::/0). In that case neither sourceAddrs nor sources need to be
|
||||
// consulted.
|
||||
matchAny bool
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// matchesSource reports whether the given source address is covered
|
||||
// by this rule's source list.
|
||||
func (r *PeerRule) matchesSource(src netip.Addr) bool {
|
||||
if r.matchAny {
|
||||
return true
|
||||
}
|
||||
if _, ok := r.sourceAddrs[src]; ok {
|
||||
return true
|
||||
}
|
||||
for _, p := range r.sources {
|
||||
if p.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *PeerRule) ID() string {
|
||||
func (r *PeerRule) ID() firewall.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
id firewall.RuleID
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
@@ -39,6 +64,6 @@ type RouteRule struct {
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *RouteRule) ID() string {
|
||||
func (r *RouteRule) ID() firewall.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// countRulesForAddr reports how many rules in the given slice match
|
||||
// the supplied source address.
|
||||
func countRulesForAddr(rules peerRules, src netip.Addr) int {
|
||||
n := 0
|
||||
for _, r := range rules {
|
||||
if r.matchesSource(src) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// findRuleByID returns true if the rules slice contains a rule with
|
||||
// the given id whose source set covers src.
|
||||
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
|
||||
for _, r := range rules {
|
||||
if r.id == id && r.matchesSource(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pfx converts a single net.IP into the []netip.Prefix form
|
||||
// AddFilterRule expects. A nil or unspecified address becomes a /0
|
||||
// ("match any") prefix in the matching family; any other address
|
||||
// becomes its /32 (or /128) host prefix.
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, _ := netip.AddrFromSlice(ip)
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -285,6 +285,14 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
||||
trace.SourceIP = srcIP
|
||||
trace.DestinationIP = dstIP
|
||||
|
||||
// A fragment or otherwise truncated packet has no transport layer.
|
||||
// The inbound datapath drops these via isValidPacket; the tracer must
|
||||
// guard explicitly since every downstream stage reads d.decoded[1].
|
||||
if len(d.decoded) < 2 {
|
||||
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
|
||||
return trace
|
||||
}
|
||||
|
||||
// Determine protocol and ports
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
|
||||
@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
m, err := Create(ifaceMock, nil, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
|
||||
@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
||||
|
||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||
|
||||
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
|
||||
; or HKCU by legacy installers.
|
||||
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Create autostart registry entry based on checkbox
|
||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||
${If} $AutostartEnabled == "1"
|
||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||
${Else}
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DetailPrint "Autostart not enabled by user"
|
||||
${EndIf}
|
||||
|
||||
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||
DetailPrint "Terminating Netbird UI process..."
|
||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||
|
||||
; Remove autostart entries from every view a previous installer may have used.
|
||||
; Remove autostart registry entry
|
||||
DetailPrint "Removing autostart registry entry if exists..."
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
SetRegView 32
|
||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||
DeleteRegKey HKLM "${REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
|
||||
DeleteRegKey HKLM "${UNINSTALL_PATH}"
|
||||
SetRegView 64
|
||||
|
||||
; Handle data deletion based on checkbox
|
||||
DetailPrint "Checking if user requested data deletion..."
|
||||
|
||||
190
client/internal/acl/dispatch_test.go
Normal file
190
client/internal/acl/dispatch_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
|
||||
// invariant: the backends classify a rule as a peer rule purely by the
|
||||
// absence of a destination (neither prefix nor set). A default route
|
||||
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
|
||||
// a route, not collapse into the peer path.
|
||||
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
|
||||
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
|
||||
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
|
||||
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
|
||||
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
|
||||
}
|
||||
|
||||
// A zero-value Network is the only peer-rule shape.
|
||||
var empty fwmgr.Network
|
||||
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
|
||||
assert.False(t, empty.IsSet(), "zero Network must not be a set")
|
||||
}
|
||||
|
||||
// TestDetermineDestinationAlwaysRoute verifies determineDestination
|
||||
// never yields an empty Network for a valid route rule: every branch
|
||||
// (static prefix, default route, dynamic with/without domains, with and
|
||||
// without a local resolver) produces a destination that classifies as a
|
||||
// route. If this regresses, a route rule would be dispatched down the
|
||||
// peer path, which matches on source only.
|
||||
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
|
||||
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
rule *mgmProto.RouteFirewallRule
|
||||
resolver bool
|
||||
sources []netip.Prefix
|
||||
}{
|
||||
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
|
||||
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
|
||||
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
|
||||
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
|
||||
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
|
||||
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
|
||||
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, dest.IsPrefix() || dest.IsSet(),
|
||||
"destination must classify as a route, got empty Network")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// countingFirewall wraps a real firewall.Manager and counts filter-rule
|
||||
// add/delete calls so a test can assert how many backing rules the acl
|
||||
// manager actually creates and tears down.
|
||||
type countingFirewall struct {
|
||||
fwmgr.Manager
|
||||
mu sync.Mutex
|
||||
addCalls int
|
||||
dels int
|
||||
ruleIDs map[fwmgr.RuleID]struct{}
|
||||
}
|
||||
|
||||
// distinctRules returns the number of distinct backing rules the
|
||||
// backend produced. Because the backend dedups identical content,
|
||||
// repeated AddFilterRule calls for the same rule resolve to one id.
|
||||
func (f *countingFirewall) distinctRules() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return len(f.ruleIDs)
|
||||
}
|
||||
|
||||
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
|
||||
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err == nil {
|
||||
f.mu.Lock()
|
||||
f.addCalls++
|
||||
if f.ruleIDs == nil {
|
||||
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
|
||||
}
|
||||
if rule != nil {
|
||||
f.ruleIDs[rule.ID()] = struct{}{}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return rule, err
|
||||
}
|
||||
|
||||
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||
err := f.Manager.DeleteFilterRule(r)
|
||||
if err == nil {
|
||||
f.mu.Lock()
|
||||
f.dels++
|
||||
delete(f.ruleIDs, r.ID())
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
|
||||
t.Helper()
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
fw := &countingFirewall{Manager: realFW}
|
||||
cleanup := func() {
|
||||
require.NoError(t, realFW.Close(nil))
|
||||
ctrl.Finish()
|
||||
}
|
||||
return NewDefaultManager(fw), fw, cleanup
|
||||
}
|
||||
|
||||
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
|
||||
// the backends rely on: two policies that authorize an identical flow
|
||||
// (same selector and sources) collapse to a single backing firewall
|
||||
// rule, and that rule survives until BOTH policies are gone. This is
|
||||
// why the backend can dedup on add without refcounting on delete: the
|
||||
// acl manager's pair key matches the backend's content key, so add and
|
||||
// delete stay balanced per content key across full-state reapplies.
|
||||
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
|
||||
acl, fw, cleanup := newCountingACL(t)
|
||||
defer cleanup()
|
||||
|
||||
ruleA := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
ruleB := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-B"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
// Both policies present: identical content collapses to one rule.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
|
||||
|
||||
// Drop policy A only: the shared rule is still authorized by B, so
|
||||
// nothing is deleted.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
|
||||
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Drop policy B too: now the content key has no authorizer and the
|
||||
// single backing rule is removed exactly once.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
|
||||
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
}
|
||||
318
client/internal/acl/grouping_test.go
Normal file
318
client/internal/acl/grouping_test.go
Normal file
@@ -0,0 +1,318 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
|
||||
// with identical selectors but different PolicyIDs do NOT get merged
|
||||
// into one group, so each policy's sources merge under its own
|
||||
// attribution id. (Identical-content groups may still dedup to one
|
||||
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
|
||||
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-B"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
|
||||
// belonging to the same policy don't merge.
|
||||
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "2001:db8::1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "rules of different families must produce separate groups")
|
||||
|
||||
var sawV4, sawV6 bool
|
||||
for _, g := range groups {
|
||||
require.Len(t, g.sources, 1)
|
||||
if g.sources[0].Addr().Is4() {
|
||||
sawV4 = true
|
||||
}
|
||||
if g.sources[0].Addr().Is6() {
|
||||
sawV6 = true
|
||||
}
|
||||
}
|
||||
assert.True(t, sawV4 && sawV6)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
|
||||
// FirewallRule carrying both v4 and v6 source prefixes is split into one
|
||||
// group per family. Each backend keys a rule to a single family, so a
|
||||
// group whose sources span families would mismatch the other family's
|
||||
// sources. mgmt normally emits one rule per family; this guards against
|
||||
// a mixed-family rule slipping through.
|
||||
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
|
||||
srcs := [][]byte{
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
SourcePrefixes: srcs,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
|
||||
|
||||
for _, g := range groups {
|
||||
require.Len(t, g.sources, 2)
|
||||
v6 := prefixIsV6(g.sources[0])
|
||||
for _, s := range g.sources {
|
||||
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
|
||||
// every distinguishing field (policy, family, direction, action,
|
||||
// proto, port) collapse into a single multi-source group.
|
||||
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
|
||||
mk := func(peerIP string) *mgmProto.FirewallRule {
|
||||
return &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: peerIP,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.Len(t, groups[0].sources, 3)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
|
||||
// selector key: rules differing only in port must not merge, and a
|
||||
// single port must not merge with a range. A regression dropping the
|
||||
// port from the key would collapse rules for different ports into one.
|
||||
func TestGroupPeerRulesPortSeparates(t *testing.T) {
|
||||
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
|
||||
return &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: peerIP,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
|
||||
}
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
|
||||
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
|
||||
})
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "rules on different ports must not merge")
|
||||
|
||||
rangeRule := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
|
||||
}
|
||||
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2, "a single port and a range must not merge")
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
|
||||
// new sourcePrefixes wire field is consumed and produces a
|
||||
// multi-source group in one shot (no client-side merging needed).
|
||||
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
|
||||
srcs := [][]byte{
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
SourcePrefixes: srcs,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 1)
|
||||
require.Len(t, groups[0].sources, 3)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
|
||||
// and drop rules with the same selector don't merge.
|
||||
func TestGroupPeerRulesActionSeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
require.NoError(t, denyErr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, groups, 2)
|
||||
}
|
||||
|
||||
// failingDeleteFirewall wraps a real firewall.Manager and forces the
|
||||
// next N DeleteFilterRule calls to fail. Used to verify that the acl
|
||||
// manager retains rules whose deletion was rejected by the backend,
|
||||
// so they get retried on the next ApplyFiltering pass instead of
|
||||
// becoming orphans.
|
||||
type failingDeleteFirewall struct {
|
||||
fwmgr.Manager
|
||||
failCount int
|
||||
}
|
||||
|
||||
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||
if f.failCount > 0 {
|
||||
f.failCount--
|
||||
return errors.New("simulated delete failure")
|
||||
}
|
||||
return f.Manager.DeleteFilterRule(r)
|
||||
}
|
||||
|
||||
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
|
||||
// transient DeleteFilterRule error doesn't make the acl manager
|
||||
// forget about a rule. The rule must remain in peerRulesPairs so the
|
||||
// next ApplyFiltering pass attempts the delete again.
|
||||
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
||||
|
||||
fw := &failingDeleteFirewall{Manager: realFW}
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First pass: install a rule.
|
||||
netmap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(netmap1, false)
|
||||
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
|
||||
|
||||
// Second pass: remove the rule from the map. The backend will
|
||||
// fail the delete; the acl manager must retain the rule.
|
||||
fw.failCount = 1
|
||||
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
|
||||
acl.ApplyFiltering(netmap2, false)
|
||||
require.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"rule must be retained when DeleteFilterRule fails so it gets retried")
|
||||
|
||||
// Third pass: same map, backend no longer fails. The rule
|
||||
// should now succeed in being removed.
|
||||
acl.ApplyFiltering(netmap2, false)
|
||||
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
|
||||
}
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
type RuleID string
|
||||
// RuleID aliases manager.RuleID so existing nbid.RuleID references
|
||||
// keep working while the canonical type lives in the firewall package.
|
||||
type RuleID = manager.RuleID
|
||||
|
||||
func (r RuleID) ID() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func GenerateRouteRuleKey(
|
||||
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
|
||||
func GenerateRuleID(
|
||||
sources []netip.Prefix,
|
||||
destination manager.Network,
|
||||
proto manager.Protocol,
|
||||
@@ -24,6 +24,7 @@ func GenerateRouteRuleKey(
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
sources = slices.Clone(sources)
|
||||
manager.SortPrefixes(sources)
|
||||
|
||||
h := sha256.New()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
@@ -23,6 +21,10 @@ import (
|
||||
|
||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||
|
||||
// ErrNoRuleReturned is returned when the firewall backend reports success
|
||||
// from AddFilterRule but yields no rule to track.
|
||||
var ErrNoRuleReturned = errors.New("backend returned no rule")
|
||||
|
||||
// Manager is a ACL rules manager
|
||||
type Manager interface {
|
||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||
@@ -31,17 +33,46 @@ type Manager interface {
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
routeRules map[id.RuleID]firewall.Rule
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// peerRuleGroup collapses a set of single-source FirewallRules sharing
|
||||
// the same selector into one multi-source rule to push to the backend.
|
||||
type peerRuleGroup struct {
|
||||
direction mgmProto.RuleDirection
|
||||
action mgmProto.RuleAction
|
||||
protocol mgmProto.RuleProtocol
|
||||
port *mgmProto.PortInfo
|
||||
// legacyPort is used only when PortInfo is empty (old management).
|
||||
legacyPort string
|
||||
policyID []byte
|
||||
sources []netip.Prefix
|
||||
}
|
||||
|
||||
// peerRuleKey is the comparable selector that decides which single-source
|
||||
// rules merge into one group. Rules with an equal key collapse into one
|
||||
// multi-source backend rule. PortInfo is flattened into its scalar fields
|
||||
// so the key compares by value; policyID keeps policies separate so two
|
||||
// policies authorizing different peers don't merge under one attribution.
|
||||
type peerRuleKey struct {
|
||||
v6 bool
|
||||
policyID string
|
||||
direction mgmProto.RuleDirection
|
||||
action mgmProto.RuleAction
|
||||
protocol mgmProto.RuleProtocol
|
||||
legacyPort string
|
||||
port uint16
|
||||
rangeStart uint16
|
||||
rangeEnd uint16
|
||||
}
|
||||
|
||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||
return &DefaultManager{
|
||||
firewall: fm,
|
||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||
routeRules: make(map[id.RuleID]struct{}),
|
||||
routeRules: make(map[id.RuleID]firewall.Rule),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,10 +99,12 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
time.Since(start), total)
|
||||
}()
|
||||
|
||||
d.applyPeerACLs(networkMap)
|
||||
if err := d.applyPeerACLs(networkMap); err != nil {
|
||||
log.Errorf("apply peer ACLs: %v", err)
|
||||
}
|
||||
|
||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||
log.Errorf("apply route ACLs: %v", err)
|
||||
}
|
||||
|
||||
if err := d.firewall.Flush(); err != nil {
|
||||
@@ -79,7 +112,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
|
||||
rules := networkMap.FirewallRules
|
||||
|
||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||
@@ -102,59 +135,158 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
)
|
||||
}
|
||||
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
// Group incoming single-source rules from management by their
|
||||
// (direction, action, proto, port) selector and merge sources.
|
||||
// One call to the firewall backend per merged rule.
|
||||
// A deny we cannot decode would leave its traffic unblocked, so skip
|
||||
// the whole pass and keep existing rules until the next sync.
|
||||
groups, denyErr, err := groupPeerRules(rules)
|
||||
if denyErr != nil {
|
||||
return fmt.Errorf("decode deny rule sources: %w", denyErr)
|
||||
}
|
||||
|
||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||
// the missing deny. Currently we accumulate errors and continue.
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
var merr *multierror.Error
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
d.ipsetCounter++
|
||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetByRuleSelectors[selector] = ipsetName
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// Apply denies first. A deny that fails to install is a security
|
||||
// failure (fail-open), so if any deny errors we roll back the
|
||||
// denies we already installed in this pass and bail out without
|
||||
// installing any accept. Pre-existing rules stay untouched until
|
||||
// the next successful pass clears them.
|
||||
denies, accepts := splitDenyAccept(groups)
|
||||
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
|
||||
return fmt.Errorf("install deny rules: %w", err)
|
||||
}
|
||||
|
||||
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
// Tear down rules that disappeared from the networkmap. Any rule
|
||||
// the backend refuses to delete stays in our tracking so it gets
|
||||
// retried on the next ApplyFiltering. Otherwise a transient
|
||||
// delete failure would leak the rule in the firewall until the
|
||||
// process exits.
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; ok {
|
||||
continue
|
||||
}
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
var remaining []firewall.Rule
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
|
||||
remaining = append(remaining, rule)
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
newRulePairs[pairID] = remaining
|
||||
}
|
||||
}
|
||||
d.peerRulesPairs = newRulePairs
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// installPeerGroups applies each group and records the resulting rule
|
||||
// pairs in newRulePairs. With atomic set (deny rules), a single failure
|
||||
// rolls back every rule installed in this call and returns, leaving the
|
||||
// firewall exactly as before: denies are fail-closed and must be applied
|
||||
// all-or-nothing. With atomic unset (accept rules), failures are
|
||||
// accumulated and the remaining groups still install, so one malformed
|
||||
// allow cannot drop every other legitimate allow in the pass.
|
||||
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
|
||||
var freshlyInstalled []id.RuleID
|
||||
var merr *multierror.Error
|
||||
for _, g := range groups {
|
||||
pairID, rulePair, err := d.applyPeerGroup(g)
|
||||
if err != nil {
|
||||
if atomic {
|
||||
d.rollbackInstalled(freshlyInstalled)
|
||||
return fmt.Errorf("apply firewall rule: %w", err)
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
}
|
||||
if len(rulePair) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, existed := d.peerRulesPairs[pairID]; !existed {
|
||||
freshlyInstalled = append(freshlyInstalled, pairID)
|
||||
}
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
|
||||
var merr *multierror.Error
|
||||
for _, pairID := range pairIDs {
|
||||
for _, rule := range d.peerRulesPairs[pairID] {
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
|
||||
}
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
}
|
||||
if err := nberrors.FormatErrorOrNil(merr); err != nil {
|
||||
log.Errorf("rollback peer rules: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
|
||||
protocol, err := ConvertToFirewallProtocol(g.protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||
}
|
||||
action, err := convertFirewallAction(g.action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||
}
|
||||
port, err := resolveGroupPort(g)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
var fwRule firewall.Rule
|
||||
switch g.direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return "", nil, nil
|
||||
}
|
||||
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
if fwRule == nil {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Derive the pair id from the backend rule, like the route path:
|
||||
// the backend dedups identical content, so two policies authorizing
|
||||
// the same flow resolve to the same id and a single backing rule.
|
||||
return fwRule.ID(), []firewall.Rule{fwRule}, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
|
||||
var merr *multierror.Error
|
||||
|
||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||
// Apply new rules - firewall manager will return the existing rule if already present
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||
@@ -163,16 +295,18 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
||||
}
|
||||
continue
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
newRouteRules[addedRule.ID()] = addedRule
|
||||
}
|
||||
|
||||
// Clean up old firewall rules
|
||||
for id := range d.routeRules {
|
||||
if _, exists := newRouteRules[id]; !exists {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||
}
|
||||
// implicitly deleted from the map
|
||||
// Tear down old route rules; retain ones the backend refused so a
|
||||
// transient failure doesn't leave orphaned rules in the firewall.
|
||||
for ruleID, rule := range d.routeRules {
|
||||
if _, exists := newRouteRules[ruleID]; exists {
|
||||
continue
|
||||
}
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
|
||||
newRouteRules[ruleID] = rule
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,102 +314,196 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
|
||||
if len(rule.SourceRanges) == 0 {
|
||||
return "", ErrSourceRangesEmpty
|
||||
return nil, ErrSourceRangesEmpty
|
||||
}
|
||||
|
||||
var sources []netip.Prefix
|
||||
for _, sourceRange := range rule.SourceRanges {
|
||||
source, err := netip.ParsePrefix(sourceRange)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse source range: %w", err)
|
||||
return nil, fmt.Errorf("parse source range: %w", err)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
sources = append(sources, firewall.UnmapPrefix(source))
|
||||
}
|
||||
|
||||
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("determine destination: %w", err)
|
||||
return nil, fmt.Errorf("determine destination: %w", err)
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid protocol: %w", err)
|
||||
return nil, fmt.Errorf("invalid protocol: %w", err)
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(rule.Action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid action: %w", err)
|
||||
return nil, fmt.Errorf("invalid action: %w", err)
|
||||
}
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
return nil, fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
if addedRule == nil {
|
||||
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
|
||||
}
|
||||
|
||||
return id.RuleID(addedRule.ID()), nil
|
||||
return addedRule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
ip, err := extractRuleIP(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
// splitDenyAccept partitions groups by action so denies can be
|
||||
// applied before accepts. Order within each bucket is preserved.
|
||||
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
|
||||
for _, g := range groups {
|
||||
if g.action == mgmProto.RuleAction_DROP {
|
||||
denies = append(denies, g)
|
||||
} else {
|
||||
accepts = append(accepts, g)
|
||||
}
|
||||
}
|
||||
return denies, accepts
|
||||
}
|
||||
|
||||
// groupPeerRules merges single-source rules sharing a selector into
|
||||
// multi-source groups. It splits source-decode failures by action:
|
||||
// denyErr is non-nil when a deny rule could not be decoded, which is a
|
||||
// fail-open risk the caller must treat as fatal for the pass; err
|
||||
// carries the tolerable accept-rule failures the caller can log and
|
||||
// continue past.
|
||||
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
|
||||
var denyMerr, acceptMerr *multierror.Error
|
||||
byKey := make(map[peerRuleKey]*peerRuleGroup)
|
||||
order := make([]peerRuleKey, 0)
|
||||
|
||||
for _, r := range rules {
|
||||
srcs, decErr := extractRuleSources(r)
|
||||
if decErr != nil {
|
||||
if r.Action == mgmProto.RuleAction_DROP {
|
||||
denyMerr = multierror.Append(denyMerr, decErr)
|
||||
} else {
|
||||
acceptMerr = multierror.Append(acceptMerr, decErr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// A single FirewallRule normally carries one address family, but
|
||||
// split by family defensively: each backend keys a rule to one
|
||||
// family and would mismatch sources of the other, so a group's
|
||||
// sources must never span families.
|
||||
v4, v6 := splitPrefixesByFamily(srcs)
|
||||
for _, sub := range []struct {
|
||||
isV6 bool
|
||||
sources []netip.Prefix
|
||||
}{{false, v4}, {true, v6}} {
|
||||
if len(sub.sources) == 0 {
|
||||
continue
|
||||
}
|
||||
key := ruleGroupKey(r, sub.isV6)
|
||||
g, ok := byKey[key]
|
||||
if !ok {
|
||||
g = &peerRuleGroup{
|
||||
direction: r.Direction,
|
||||
action: r.Action,
|
||||
protocol: r.Protocol,
|
||||
port: r.PortInfo,
|
||||
legacyPort: r.Port,
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
byKey[key] = g
|
||||
order = append(order, key)
|
||||
}
|
||||
g.sources = append(g.sources, sub.sources...)
|
||||
}
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
out := make([]*peerRuleGroup, 0, len(order))
|
||||
for _, k := range order {
|
||||
out = append(out, byKey[k])
|
||||
}
|
||||
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
|
||||
}
|
||||
|
||||
func prefixIsV6(p netip.Prefix) bool {
|
||||
return p.Addr().Is6() && !p.Addr().Is4In6()
|
||||
}
|
||||
|
||||
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
|
||||
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
|
||||
for _, p := range prefixes {
|
||||
if prefixIsV6(p) {
|
||||
v6 = append(v6, p)
|
||||
} else {
|
||||
v4 = append(v4, p)
|
||||
}
|
||||
}
|
||||
return v4, v6
|
||||
}
|
||||
|
||||
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
|
||||
// rule's source family: mgmt emits one rule per family and mixing them
|
||||
// would break ICMP-variant selection in uspfilter.
|
||||
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
|
||||
k := peerRuleKey{
|
||||
v6: v6,
|
||||
policyID: string(r.PolicyID),
|
||||
direction: r.Direction,
|
||||
action: r.Action,
|
||||
protocol: r.Protocol,
|
||||
legacyPort: r.Port,
|
||||
}
|
||||
if pi := r.PortInfo; pi != nil {
|
||||
k.port = uint16(pi.GetPort())
|
||||
if rng := pi.GetRange(); rng != nil {
|
||||
k.rangeStart = uint16(rng.GetStart())
|
||||
k.rangeEnd = uint16(rng.GetEnd())
|
||||
}
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// extractRuleSources returns all source prefixes the rule applies to.
|
||||
// New management populates sourcePrefixes; older management sets PeerIP.
|
||||
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
|
||||
for _, raw := range r.SourcePrefixes {
|
||||
addr, err := netiputil.DecodeAddr(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(r.Action)
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if !portInfoEmpty(r.PortInfo) {
|
||||
port = convertPortInfo(r.PortInfo)
|
||||
} else if r.Port != "" {
|
||||
// old version of management, single port
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
|
||||
if !portInfoEmpty(g.port) {
|
||||
return convertPortInfo(g.port), nil
|
||||
}
|
||||
if g.legacyPort != "" {
|
||||
value, err := strconv.Atoi(g.legacyPort)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||
return nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
port = &firewall.Port{
|
||||
return &firewall.Port{
|
||||
Values: []uint16{uint16(value)},
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
// return traffic for outbound connections if firewall is stateless
|
||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return ruleID, rules, nil
|
||||
// nolint:nilnil // a nil port legitimately means "no port restriction"
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
@@ -294,85 +522,9 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip netip.Addr,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
) id.RuleID {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
|
||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
// ConvertToFirewallProtocol maps a management rule protocol to the
|
||||
// firewall protocol type.
|
||||
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
return firewall.ProtocolTCP, nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
@@ -76,9 +77,9 @@ func TestDefaultManager(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
existedPairs := map[fwmanager.RuleID]struct{}{}
|
||||
for id := range acl.peerRulesPairs {
|
||||
existedPairs[id.ID()] = struct{}{}
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
@@ -105,7 +106,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
if _, ok := existedPairs[id.ID()]; ok {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,7 +360,13 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", port)
|
||||
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
|
||||
// alias by default, ensure explicit ip for localhost.
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
addr := net.JoinHostPort(host, port)
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
case entry.Pattern == ".":
|
||||
return true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
return strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
default:
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
|
||||
@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary mismatch (suffix overlap)",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.ab.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard multi-label match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.y.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on multi-label apex",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on unrelated suffix containment",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "notexample.com.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard accepts pattern registered without trailing dot",
|
||||
handlerDomain: "*.b.test",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority matching handler should be called
|
||||
},
|
||||
{
|
||||
name: "overlapping wildcard suffixes route to correct handler",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
|
||||
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "app.ab.test.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
|
||||
@@ -26,6 +26,19 @@ type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
|
||||
// client knows about and whether that peer is currently connected. The
|
||||
// local resolver uses this to suppress A/AAAA answers whose RDATA points
|
||||
// at a disconnected peer (typical case: a synthesized private-service
|
||||
// record pointing at an embedded proxy peer that just went offline).
|
||||
//
|
||||
// known=false means the IP isn't in the local peerstore at all — the
|
||||
// record is left alone (it points at something outside our mesh, e.g.
|
||||
// a non-peer upstream).
|
||||
type PeerConnectivity interface {
|
||||
IsConnectedByIP(ip string) (known, connected bool)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
@@ -33,6 +46,11 @@ type Resolver struct {
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
// peerConn, when non-nil, is consulted on every A/AAAA answer to
|
||||
// drop records pointing at disconnected peers. nil disables the
|
||||
// filter and preserves the legacy "return whatever is registered"
|
||||
// behaviour for callers that never wire a status source.
|
||||
peerConn PeerConnectivity
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
|
||||
}
|
||||
}
|
||||
|
||||
// SetPeerConnectivity wires the per-IP connectivity check used to filter
|
||||
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
|
||||
// Safe to call multiple times; the latest value wins.
|
||||
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.peerConn = p
|
||||
}
|
||||
|
||||
func (d *Resolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
@@ -95,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
result := d.lookupRecords(logger, question)
|
||||
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
@@ -436,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
}
|
||||
}
|
||||
|
||||
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
|
||||
// a known but disconnected peer. The synthesized private-service zones
|
||||
// emit one A record per connected proxy peer in a cluster; when a peer
|
||||
// goes offline, the server-side refresh removes the record from the
|
||||
// next netmap, but the client may still hold the previous netmap for a
|
||||
// short window. This filter is the local belt to that braces — even on
|
||||
// the stale netmap, the resolver hides the offline target.
|
||||
//
|
||||
// Records pointing at unknown IPs (outside the local peerstore, e.g.
|
||||
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
|
||||
// through untouched.
|
||||
//
|
||||
// Escape hatch: if filtering would leave the answer empty AND at least
|
||||
// one record was filtered, the original list is returned. Better to
|
||||
// hand the client a record that may not respond than NXDOMAIN it
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
checker := d.peerConn
|
||||
d.mu.RUnlock()
|
||||
if checker == nil {
|
||||
return records
|
||||
}
|
||||
|
||||
kept := make([]dns.RR, 0, len(records))
|
||||
var dropped int
|
||||
for _, rr := range records {
|
||||
ip := extractRecordIP(rr)
|
||||
if ip == "" {
|
||||
kept = append(kept, rr)
|
||||
continue
|
||||
}
|
||||
known, connected := checker.IsConnectedByIP(ip)
|
||||
if known && !connected {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
kept = append(kept, rr)
|
||||
}
|
||||
if dropped == 0 {
|
||||
return records
|
||||
}
|
||||
if len(kept) == 0 {
|
||||
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
|
||||
return records
|
||||
}
|
||||
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
|
||||
return kept
|
||||
}
|
||||
|
||||
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
|
||||
// an A or AAAA record, or "" for any other record type.
|
||||
func extractRecordIP(rr dns.RR) string {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
if r.A == nil {
|
||||
return ""
|
||||
}
|
||||
return r.A.String()
|
||||
case *dns.AAAA:
|
||||
if r.AAAA == nil {
|
||||
return ""
|
||||
}
|
||||
return r.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
d.mu.Lock()
|
||||
|
||||
@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mockPeerConnectivity returns canned (known, connected) results per IP.
|
||||
// Used by the disconnected-peer filter tests below. IPs not in the map
|
||||
// are reported as unknown so the filter leaves them alone.
|
||||
type mockPeerConnectivity struct {
|
||||
byIP map[string]struct{ known, connected bool }
|
||||
}
|
||||
|
||||
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
v, ok := m.byIP[ip]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return v.known, v.connected
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
|
||||
// connectivity-aware filtering layered on top of lookupRecords:
|
||||
// when an A record's IP belongs to a known peer that's disconnected,
|
||||
// the record is dropped from the answer. Records for unknown IPs pass
|
||||
// through. If filtering would empty the answer entirely and at least
|
||||
// one record was dropped, the original list is restored (escape hatch
|
||||
// for the "all proxies offline" case).
|
||||
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
zone := "svc.cluster.netbird."
|
||||
connectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.10",
|
||||
}
|
||||
disconnectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.11",
|
||||
}
|
||||
unknownRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "203.0.113.5",
|
||||
}
|
||||
|
||||
type ipState struct{ known, connected bool }
|
||||
tests := []struct {
|
||||
name string
|
||||
records []nbdns.SimpleRecord
|
||||
connByIP map[string]ipState
|
||||
wantInOrder []string
|
||||
}{
|
||||
{
|
||||
name: "drops disconnected peer, keeps connected",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: true},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "unknown IPs pass through untouched",
|
||||
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"203.0.113.5"},
|
||||
},
|
||||
{
|
||||
name: "all disconnected falls back to original list",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: false},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "no checker wired returns all records",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
if tc.connByIP != nil {
|
||||
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
|
||||
for ip, st := range tc.connByIP {
|
||||
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
|
||||
}
|
||||
resolver.SetPeerConnectivity(cm)
|
||||
}
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: strings.TrimSuffix(zone, "."),
|
||||
Records: tc.records,
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var got *dns.Msg
|
||||
writer := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
got = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
|
||||
resolver.ServeDNS(writer, req)
|
||||
|
||||
require.NotNil(t, got, "resolver must produce a response")
|
||||
require.Len(t, got.Answer, len(tc.wantInOrder),
|
||||
"answer count must match expected: %v", tc.wantInOrder)
|
||||
for i, want := range tc.wantInOrder {
|
||||
a, ok := got.Answer[i].(*dns.A)
|
||||
require.True(t, ok, "answer[%d] must be an A record", i)
|
||||
assert.Equal(t, want, a.A.String(),
|
||||
"answer[%d] expected %s got %s", i, want, a.A.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,6 +301,11 @@ func newDefaultServer(
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
healthRefresh: make(chan struct{}, 1),
|
||||
}
|
||||
// Wire the local resolver against the peer status recorder so it can
|
||||
// suppress A/AAAA answers that point at disconnected peers (typical
|
||||
// case: synthesised private-service records pointing at an embedded
|
||||
// proxy peer that just went offline).
|
||||
defaultServer.localResolver.SetPeerConnectivity(localPeerConnectivity{statusRecorder})
|
||||
|
||||
// register with root zone, handler chain takes care of the routing
|
||||
dnsService.RegisterMux(".", handlerChain)
|
||||
@@ -1386,3 +1391,25 @@ func (s *DefaultServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// localPeerConnectivity adapts *peer.Status to local.PeerConnectivity so
|
||||
// the local resolver can ask "is this IP a known peer and is it
|
||||
// connected?" without taking on the peer package as a dependency.
|
||||
// A nil status recorder always reports known=false so the resolver
|
||||
// short-circuits to the legacy "return everything" path.
|
||||
type localPeerConnectivity struct {
|
||||
status *peer.Status
|
||||
}
|
||||
|
||||
// IsConnectedByIP looks the IP up in the peerstore and surfaces both
|
||||
// the known and connected bits. Used by Resolver.filterDisconnectedPeerAnswers.
|
||||
func (l localPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
if l.status == nil {
|
||||
return false, false
|
||||
}
|
||||
state, ok := l.status.PeerStateByIP(ip)
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return true, state.ConnStatus == peer.StatusConnected
|
||||
}
|
||||
|
||||
@@ -900,7 +900,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
||||
pf, err := uspfilter.Create(wgIface, nil, false, flowLogger, iface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -3,7 +3,6 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -160,12 +159,13 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
dnsRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||
}
|
||||
|
||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||
tcpRule, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||
}
|
||||
@@ -174,8 +174,12 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
m.fwRules = dnsRules
|
||||
m.tcpRules = tcpRules
|
||||
if dnsRule != nil {
|
||||
m.fwRules = []firewall.Rule{dnsRule}
|
||||
}
|
||||
if tcpRule != nil {
|
||||
m.tcpRules = []firewall.Rule{tcpRule}
|
||||
}
|
||||
|
||||
m.registerNetstackServices()
|
||||
|
||||
@@ -209,12 +213,12 @@ func (m *Manager) unregisterNetstackServices() {
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range m.fwRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
for _, rule := range m.tcpRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,14 +640,14 @@ func (e *Engine) initFirewall() error {
|
||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||
|
||||
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
if _, err := e.firewall.AddFilterRule(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
firewallManager.Network{},
|
||||
firewallManager.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
firewallManager.ActionAccept,
|
||||
"",
|
||||
); err != nil {
|
||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||
return nil
|
||||
@@ -697,7 +697,7 @@ func (e *Engine) blockLanAccess() {
|
||||
if network.Addr().Is6() {
|
||||
source = v6
|
||||
}
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
if _, err := e.firewall.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{source},
|
||||
firewallManager.Network{Prefix: network},
|
||||
@@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||
// See Engine.SetPerformance. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetPerformance applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetPerformance(t Performance) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
@@ -2346,7 +2369,7 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
|
||||
var merr *multierror.Error
|
||||
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
||||
for _, rule := range rules {
|
||||
proto, err := convertToFirewallProtocol(rule.GetProtocol())
|
||||
proto, err := acl.ConvertToFirewallProtocol(rule.GetProtocol())
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
|
||||
continue
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -66,8 +66,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -24,14 +24,14 @@ type RulePair struct {
|
||||
type Manager struct {
|
||||
dnatFirewall DNATFirewall
|
||||
|
||||
rules map[string]RulePair // keys is the ID of the ForwardRule
|
||||
rules map[firewall.RuleID]RulePair
|
||||
rulesMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewManager(dnatFirewall DNATFirewall) *Manager {
|
||||
return &Manager{
|
||||
dnatFirewall: dnatFirewall,
|
||||
rules: make(map[string]RulePair),
|
||||
rules: make(map[firewall.RuleID]RulePair),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
toDelete := make(map[string]RulePair, len(h.rules))
|
||||
toDelete := make(map[firewall.RuleID]RulePair, len(h.rules))
|
||||
for id, r := range h.rules {
|
||||
toDelete[id] = r
|
||||
}
|
||||
@@ -59,6 +59,10 @@ func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
|
||||
continue
|
||||
}
|
||||
if rule == nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': backend returned no rule", fwdRule.String()))
|
||||
continue
|
||||
}
|
||||
log.Infof("forward rule has been added '%s'", fwdRule)
|
||||
h.rules[id] = RulePair{
|
||||
ForwardRule: fwdRule,
|
||||
@@ -90,7 +94,7 @@ func (h *Manager) Close() error {
|
||||
}
|
||||
}
|
||||
|
||||
h.rules = make(map[string]RulePair)
|
||||
h.rules = make(map[firewall.RuleID]RulePair)
|
||||
return nberrors.FormatErrorOrNil(mErr)
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,11 @@ var (
|
||||
)
|
||||
|
||||
type MocFwRule struct {
|
||||
id string
|
||||
id firewall.RuleID
|
||||
}
|
||||
|
||||
func (m *MocFwRule) ID() string {
|
||||
return string(m.id)
|
||||
func (m *MocFwRule) ID() firewall.RuleID {
|
||||
return m.id
|
||||
}
|
||||
|
||||
type MockDNATFirewall struct {
|
||||
|
||||
@@ -10,21 +10,6 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
|
||||
switch protocol {
|
||||
case mgmProto.RuleProtocol_TCP:
|
||||
return firewallManager.ProtocolTCP, nil
|
||||
case mgmProto.RuleProtocol_UDP:
|
||||
return firewallManager.ProtocolUDP, nil
|
||||
case mgmProto.RuleProtocol_ICMP:
|
||||
return firewallManager.ProtocolICMP, nil
|
||||
case mgmProto.RuleProtocol_ALL:
|
||||
return firewallManager.ProtocolALL, nil
|
||||
default:
|
||||
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||
}
|
||||
}
|
||||
|
||||
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
|
||||
if portInfo == nil {
|
||||
return nil, errors.New("portInfo cannot be nil")
|
||||
|
||||
@@ -50,7 +50,7 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
route, flags, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
@@ -66,6 +66,10 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
if systemops.IgnoreAddedDefaultRoute(flags) {
|
||||
log.Debugf("Network monitor: ignoring added default route via %s, interface %s, flags %#x", route.Gw, intf, flags)
|
||||
continue
|
||||
}
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
@@ -78,22 +82,26 @@ func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Next
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, int, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
return nil, 0, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
return nil, 0, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
return nil, 0, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
r, err := systemops.MsgToRoute(msg)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return r, msg.Flags, nil
|
||||
}
|
||||
|
||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/id"
|
||||
"github.com/netbirdio/netbird/client/internal/peer/worker"
|
||||
"github.com/netbirdio/netbird/client/internal/portforward"
|
||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
}
|
||||
|
||||
// Fallback to deterministic key if no NetBird PSK is configured
|
||||
determKey, err := conn.rosenpassDetermKey()
|
||||
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
|
||||
if err != nil {
|
||||
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
|
||||
return nil
|
||||
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
|
||||
return determKey
|
||||
}
|
||||
|
||||
// todo: move this logic into Rosenpass package
|
||||
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
|
||||
lk := []byte(conn.config.LocalKey)
|
||||
rk := []byte(conn.config.Key) // remote key
|
||||
var keyInput []byte
|
||||
if string(lk) > string(rk) {
|
||||
//nolint:gocritic
|
||||
keyInput = append(lk[:16], rk[:16]...)
|
||||
} else {
|
||||
//nolint:gocritic
|
||||
keyInput = append(rk[:16], lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
|
||||
func isController(config ConnConfig) bool {
|
||||
return config.LocalKey > config.Key
|
||||
}
|
||||
|
||||
@@ -185,9 +185,12 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
|
||||
return s.eventsChan
|
||||
}
|
||||
|
||||
// Status holds a state of peers, signal, management connections and relays
|
||||
// Status holds a state of peers, signal, management connections and relays.
|
||||
// mux is an RWMutex so hot read paths (notably PeerStateByIP, called for
|
||||
// every private-service request) don't contend against each other.
|
||||
// Pure read methods take RLock; anything that mutates state takes Lock.
|
||||
type Status struct {
|
||||
mux sync.Mutex
|
||||
mux sync.RWMutex
|
||||
peers map[string]State
|
||||
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
|
||||
signalState bool
|
||||
@@ -283,8 +286,8 @@ func (d *Status) AddPeer(peerPubKey string, fqdn string, ip string, ipv6 string)
|
||||
|
||||
// GetPeer adds peer to Daemon status map
|
||||
func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
state, ok := d.peers[peerPubKey]
|
||||
if !ok {
|
||||
@@ -294,8 +297,8 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
|
||||
}
|
||||
|
||||
func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if state.IP == ip {
|
||||
@@ -305,6 +308,25 @@ func (d *Status) PeerByIP(ip string) (string, bool) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// PeerStateByIP returns the full peer State for the given tunnel IP.
|
||||
// Matches against either the IPv4 (State.IP) or IPv6 (State.IPv6) tunnel
|
||||
// address so dual-stack peers are reachable on either family. Returns the
|
||||
// zero State and false when no peer matches or the input is empty.
|
||||
func (d *Status) PeerStateByIP(ip string) (State, bool) {
|
||||
if ip == "" {
|
||||
return State{}, false
|
||||
}
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
for _, state := range d.peers {
|
||||
if (state.IP != "" && state.IP == ip) || (state.IPv6 != "" && state.IPv6 == ip) {
|
||||
return state, true
|
||||
}
|
||||
}
|
||||
return State{}, false
|
||||
}
|
||||
|
||||
// RemovePeer removes peer from Daemon status map
|
||||
func (d *Status) RemovePeer(peerPubKey string) error {
|
||||
d.mux.Lock()
|
||||
@@ -702,8 +724,8 @@ func (d *Status) UnsubscribePeerStateChanges(subscription *StatusChangeSubscript
|
||||
|
||||
// GetLocalPeerState returns the local peer state
|
||||
func (d *Status) GetLocalPeerState() LocalPeerState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.localPeer.Clone()
|
||||
}
|
||||
|
||||
@@ -909,8 +931,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
|
||||
}
|
||||
|
||||
func (d *Status) GetRosenpassState() RosenpassState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return RosenpassState{
|
||||
d.rosenpassEnabled,
|
||||
d.rosenpassPermissive,
|
||||
@@ -918,14 +940,14 @@ func (d *Status) GetRosenpassState() RosenpassState {
|
||||
}
|
||||
|
||||
func (d *Status) GetLazyConnection() bool {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return d.lazyConnectionEnabled
|
||||
}
|
||||
|
||||
func (d *Status) GetManagementState() ManagementState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return ManagementState{
|
||||
d.mgmAddress,
|
||||
d.managementState,
|
||||
@@ -951,8 +973,8 @@ func (d *Status) UpdateLatency(pubKey string, latency time.Duration) error {
|
||||
|
||||
// IsLoginRequired determines if a peer's login has expired.
|
||||
func (d *Status) IsLoginRequired() bool {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
// if peer is connected to the management then login is not expired
|
||||
if d.managementState {
|
||||
@@ -967,8 +989,8 @@ func (d *Status) IsLoginRequired() bool {
|
||||
}
|
||||
|
||||
func (d *Status) GetSignalState() SignalState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return SignalState{
|
||||
d.signalAddress,
|
||||
d.signalState,
|
||||
@@ -978,8 +1000,8 @@ func (d *Status) GetSignalState() SignalState {
|
||||
|
||||
// GetRelayStates returns the stun/turn/permanent relay states
|
||||
func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.relayMgr == nil {
|
||||
return d.relayStates
|
||||
}
|
||||
@@ -1008,8 +1030,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
||||
}
|
||||
|
||||
func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.ingressGwMgr == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -1018,16 +1040,16 @@ func (d *Status) ForwardingRules() []firewall.ForwardRule {
|
||||
}
|
||||
|
||||
func (d *Status) GetDNSStates() []NSGroupState {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
// shallow copy is good enough, as slices fields are currently not updated
|
||||
return slices.Clone(d.nsGroupStates)
|
||||
}
|
||||
|
||||
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
return maps.Clone(d.resolvedDomainsStates)
|
||||
}
|
||||
|
||||
@@ -1043,8 +1065,8 @@ func (d *Status) GetFullStatus() FullStatus {
|
||||
LazyConnectionEnabled: d.GetLazyConnection(),
|
||||
}
|
||||
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
|
||||
fullStatus.LocalPeerState = d.localPeer
|
||||
|
||||
@@ -1219,8 +1241,8 @@ func (d *Status) SetWgIface(wgInterface WGIfaceStatus) {
|
||||
}
|
||||
|
||||
func (d *Status) PeersStatus() (*configurer.Stats, error) {
|
||||
d.mux.Lock()
|
||||
defer d.mux.Unlock()
|
||||
d.mux.RLock()
|
||||
defer d.mux.RUnlock()
|
||||
if d.wgIface == nil {
|
||||
return nil, fmt.Errorf("wgInterface is nil, cannot retrieve peers status")
|
||||
}
|
||||
|
||||
@@ -63,6 +63,33 @@ func TestUpdatePeerState(t *testing.T) {
|
||||
assert.Equal(t, ip, state.IP, "ip should be equal")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", ""))
|
||||
req.NoError(status.AddPeer("pk-2", "peer-2.netbird", "100.64.0.11", ""))
|
||||
|
||||
state, ok := status.PeerStateByIP("100.64.0.10")
|
||||
req.True(ok, "known tunnel IP should resolve to a peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
req.Equal("peer-1.netbird", state.FQDN, "matching state must carry the right FQDN")
|
||||
|
||||
_, ok = status.PeerStateByIP("100.64.0.99")
|
||||
req.False(ok, "unknown IP must report ok=false")
|
||||
}
|
||||
|
||||
func TestStatus_PeerStateByIP_MatchesIPv6(t *testing.T) {
|
||||
status := NewRecorder("https://mgm")
|
||||
req := require.New(t)
|
||||
|
||||
req.NoError(status.AddPeer("pk-1", "peer-1.netbird", "100.64.0.10", "fd00::1"))
|
||||
|
||||
state, ok := status.PeerStateByIP("fd00::1")
|
||||
req.True(ok, "IPv6-only match must resolve to the peer state")
|
||||
req.Equal("pk-1", state.PubKey, "matching state must carry the right pub key")
|
||||
}
|
||||
|
||||
func TestStatus_UpdatePeerFQDN(t *testing.T) {
|
||||
key := "abc"
|
||||
fqdn := "peer-a.netbird.local"
|
||||
|
||||
@@ -179,8 +179,10 @@ func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) {
|
||||
}
|
||||
|
||||
dst := net.IPv4zero
|
||||
if runtime.GOOS == "linux" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux.
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
// go-netroute v0.4.0 rejects unspecified destinations client-side on Linux/Android.
|
||||
// TODO: on android/ios, use platform APIs (ConnectivityManager.getLinkProperties /
|
||||
// NWPathMonitor) when netlink-based lookup is restricted or unavailable.
|
||||
dst = net.IPv4(0, 0, 0, 1)
|
||||
}
|
||||
_, gateway, localIP, err = router.Route(dst)
|
||||
@@ -203,7 +205,7 @@ func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) {
|
||||
}
|
||||
|
||||
dst := net.IPv6zero
|
||||
if runtime.GOOS == "linux" {
|
||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||
// ::2
|
||||
dst = net.IP{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
|
||||
return hex.EncodeToString(hasher.Sum(nil))
|
||||
}
|
||||
|
||||
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
|
||||
// so tests can substitute a mock without spinning up a real UDP server.
|
||||
type rpServer interface {
|
||||
AddPeer(rp.PeerConfig) (rp.PeerID, error)
|
||||
RemovePeer(rp.PeerID) error
|
||||
Run() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
ifaceName string
|
||||
spk []byte
|
||||
@@ -36,7 +45,7 @@ type Manager struct {
|
||||
preSharedKey *[32]byte
|
||||
rpPeerIDs map[string]*rp.PeerID
|
||||
rpWgHandler *NetbirdHandler
|
||||
server *rp.Server
|
||||
server rpServer
|
||||
lock sync.Mutex
|
||||
port int
|
||||
wgIface PresharedKeySetter
|
||||
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
|
||||
|
||||
rpKeyHash := hashRosenpassKey(public)
|
||||
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
|
||||
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
|
||||
return &Manager{
|
||||
ifaceName: wgIfaceName,
|
||||
rpKeyHash: rpKeyHash,
|
||||
spk: public,
|
||||
ssk: secret,
|
||||
preSharedKey: (*[32]byte)(preSharedKey),
|
||||
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||
// rpWgHandler is created here (instead of only in generateConfig) so it
|
||||
// is never nil between NewManager and Run(). Otherwise an early
|
||||
// OnConnected call (race observed on Android, issue #4341) panics on
|
||||
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
|
||||
// replace it with a fresh handler on each Run() to clear stale peer
|
||||
// state from previous engine sessions.
|
||||
rpWgHandler: NewNetbirdHandler(),
|
||||
lock: sync.Mutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetPubKey() []byte {
|
||||
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
|
||||
|
||||
// addPeer adds a new peer to the Rosenpass server
|
||||
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
|
||||
// Defense in depth against issue #4341 (Android crash): if Run() has not
|
||||
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
|
||||
// error instead of panicking on nil-receiver dereference.
|
||||
if m.server == nil {
|
||||
return fmt.Errorf("rosenpass server not initialized")
|
||||
}
|
||||
if m.rpWgHandler == nil {
|
||||
return fmt.Errorf("rosenpass wg handler not initialized")
|
||||
}
|
||||
|
||||
var err error
|
||||
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
|
||||
if m.preSharedKey != nil {
|
||||
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
|
||||
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
|
||||
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
|
||||
}
|
||||
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
|
||||
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
|
||||
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
|
||||
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
|
||||
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
|
||||
// IPv6 so its address family matches our listening socket.
|
||||
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
|
||||
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
|
||||
pcfg.Endpoint.IP = v4.To16()
|
||||
}
|
||||
}
|
||||
peerID, err := m.server.AddPeer(pcfg)
|
||||
if err != nil {
|
||||
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
|
||||
return err
|
||||
}
|
||||
|
||||
m.server, err = rp.NewUDPServer(conf)
|
||||
server, err := rp.NewUDPServer(conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.lock.Lock()
|
||||
m.server = server
|
||||
m.lock.Unlock()
|
||||
|
||||
log.Infof("starting rosenpass server on port %d", m.port)
|
||||
|
||||
return m.server.Run()
|
||||
return server.Run()
|
||||
}
|
||||
|
||||
// Close closes the Rosenpass server
|
||||
func (m *Manager) Close() error {
|
||||
if m.server != nil {
|
||||
err := m.server.Close()
|
||||
if err != nil {
|
||||
log.Errorf("failed closing local rosenpass server")
|
||||
}
|
||||
m.server = nil
|
||||
m.lock.Lock()
|
||||
server := m.server
|
||||
m.server = nil
|
||||
m.lock.Unlock()
|
||||
if server == nil {
|
||||
return nil
|
||||
}
|
||||
if err := server.Close(); err != nil {
|
||||
log.Errorf("failed closing local rosenpass server: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,412 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
rp "cunicu.li/go-rosenpass"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// --- test doubles -----------------------------------------------------------
|
||||
|
||||
type addPeerCall struct {
|
||||
cfg rp.PeerConfig
|
||||
}
|
||||
|
||||
type removePeerCall struct {
|
||||
id rp.PeerID
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
mu sync.Mutex
|
||||
addCalls []addPeerCall
|
||||
removed []removePeerCall
|
||||
nextID rp.PeerID
|
||||
addErr error
|
||||
removeErr error
|
||||
closed bool
|
||||
ran bool
|
||||
}
|
||||
|
||||
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
|
||||
if m.addErr != nil {
|
||||
return rp.PeerID{}, m.addErr
|
||||
}
|
||||
// Increment a byte in nextID so distinct peers get distinct IDs.
|
||||
m.nextID[0]++
|
||||
return m.nextID, nil
|
||||
}
|
||||
|
||||
func (m *mockServer) RemovePeer(id rp.PeerID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.removed = append(m.removed, removePeerCall{id: id})
|
||||
return m.removeErr
|
||||
}
|
||||
|
||||
func (m *mockServer) Run() error { m.ran = true; return nil }
|
||||
func (m *mockServer) Close() error { m.closed = true; return nil }
|
||||
|
||||
type setPSKCall struct {
|
||||
peerKey string
|
||||
psk wgtypes.Key
|
||||
updateOnly bool
|
||||
}
|
||||
|
||||
type mockIface struct {
|
||||
mu sync.Mutex
|
||||
calls []setPSKCall
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
|
||||
return m.err
|
||||
}
|
||||
|
||||
// newTestManager builds a Manager with deterministic spk so tie-break
|
||||
// against a peer pubkey is controllable from tests. The provided spk byte
|
||||
// becomes the first byte; remaining bytes are zero.
|
||||
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
|
||||
spk := make([]byte, 32)
|
||||
spk[0] = spkFirstByte
|
||||
return &Manager{
|
||||
ifaceName: "wt0",
|
||||
spk: spk,
|
||||
ssk: make([]byte, 32),
|
||||
rpKeyHash: "test-hash",
|
||||
rpPeerIDs: make(map[string]*rp.PeerID),
|
||||
rpWgHandler: NewNetbirdHandler(),
|
||||
server: mock,
|
||||
}
|
||||
}
|
||||
|
||||
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
|
||||
func validWGKey(t *testing.T, lastByte byte) string {
|
||||
t.Helper()
|
||||
var k wgtypes.Key
|
||||
k[31] = lastByte
|
||||
return k.String()
|
||||
}
|
||||
|
||||
// --- pure helpers ----------------------------------------------------------
|
||||
|
||||
func TestHashRosenpassKey_Deterministic(t *testing.T) {
|
||||
key := []byte("hello-rosenpass")
|
||||
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
|
||||
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
|
||||
}
|
||||
|
||||
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
|
||||
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
|
||||
}
|
||||
|
||||
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
|
||||
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
|
||||
// can only set, not delete, so do it manually with restore via t.Cleanup.
|
||||
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
|
||||
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
|
||||
t.Cleanup(func() {
|
||||
if hadPrev {
|
||||
_ = os.Setenv(defaultLogLevelVar, prev)
|
||||
} else {
|
||||
_ = os.Unsetenv(defaultLogLevelVar)
|
||||
}
|
||||
})
|
||||
require.Equal(t, defaultLog.String(), getLogLevel().String())
|
||||
}
|
||||
|
||||
func TestGetLogLevel_Cases(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"debug": "DEBUG",
|
||||
"info": "INFO",
|
||||
"warn": "WARN",
|
||||
"error": "ERROR",
|
||||
"unknown": "INFO", // default fallback
|
||||
}
|
||||
for input, wantStr := range cases {
|
||||
input, wantStr := input, wantStr
|
||||
t.Run(input, func(t *testing.T) {
|
||||
t.Setenv(defaultLogLevelVar, input)
|
||||
require.Equal(t, wantStr, getLogLevel().String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindRandomAvailableUDPPort(t *testing.T) {
|
||||
port, err := findRandomAvailableUDPPort()
|
||||
require.NoError(t, err)
|
||||
require.Greater(t, port, 0)
|
||||
require.LessOrEqual(t, port, 65535)
|
||||
}
|
||||
|
||||
// --- addPeer ---------------------------------------------------------------
|
||||
|
||||
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv) // local spk lexicographically larger
|
||||
|
||||
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
|
||||
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.NoError(t, err)
|
||||
require.Len(t, srv.addCalls, 1)
|
||||
|
||||
ep := srv.addCalls[0].cfg.Endpoint
|
||||
require.NotNil(t, ep, "initiator side must set Endpoint")
|
||||
require.Equal(t, 7000, ep.Port)
|
||||
require.Equal(t, "100.1.1.1", ep.IP.String())
|
||||
}
|
||||
|
||||
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
|
||||
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
|
||||
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.NoError(t, err)
|
||||
|
||||
ep := srv.addCalls[0].cfg.Endpoint
|
||||
require.NotNil(t, ep)
|
||||
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
|
||||
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
|
||||
}
|
||||
|
||||
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0x00, srv) // local spk smaller
|
||||
|
||||
remotePubKey := make([]byte, 32)
|
||||
remotePubKey[0] = 0xFF
|
||||
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
|
||||
}
|
||||
|
||||
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
psk := &wgtypes.Key{0x42}
|
||||
m := newTestManager(0xFF, srv)
|
||||
m.preSharedKey = (*[32]byte)(psk)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
|
||||
}
|
||||
|
||||
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
|
||||
}
|
||||
|
||||
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAddPeer_ServerError_Propagates(t *testing.T) {
|
||||
srv := &mockServer{addErr: errors.New("boom")}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// Regression guard for issue #4341 (Android crash). If Run() has not completed
|
||||
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
|
||||
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
|
||||
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
m.rpWgHandler = nil // simulate Run() not yet completed
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "wg handler not initialized")
|
||||
}
|
||||
|
||||
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
|
||||
m := newTestManager(0xFF, nil)
|
||||
m.server = nil // simulate Run() not yet completed
|
||||
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "server not initialized")
|
||||
}
|
||||
|
||||
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
|
||||
// issue #4341 cannot occur in the window between NewManager and Run().
|
||||
func TestNewManager_PreInitializesHandler(t *testing.T) {
|
||||
psk := wgtypes.Key{}
|
||||
m, err := NewManager(&psk, "wt0")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
|
||||
}
|
||||
|
||||
func TestAddPeer_RecordsPeerID(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 5)
|
||||
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
}
|
||||
|
||||
// --- OnConnected / OnDisconnected ------------------------------------------
|
||||
|
||||
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
|
||||
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
|
||||
require.Empty(t, m.rpPeerIDs)
|
||||
}
|
||||
|
||||
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 1)
|
||||
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
|
||||
require.Len(t, srv.addCalls, 1)
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
}
|
||||
|
||||
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
m.OnDisconnected(validWGKey(t, 99))
|
||||
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
|
||||
}
|
||||
|
||||
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 1)
|
||||
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||
require.Contains(t, m.rpPeerIDs, wgKey)
|
||||
|
||||
m.OnDisconnected(wgKey)
|
||||
require.Len(t, srv.removed, 1)
|
||||
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
|
||||
}
|
||||
|
||||
// --- IsPresharedKeyInitialized ---------------------------------------------
|
||||
|
||||
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
|
||||
}
|
||||
|
||||
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
|
||||
srv := &mockServer{}
|
||||
m := newTestManager(0xFF, srv)
|
||||
|
||||
wgKey := validWGKey(t, 2)
|
||||
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
|
||||
require.False(t, m.IsPresharedKeyInitialized(wgKey))
|
||||
}
|
||||
|
||||
// --- NetbirdHandler.outputKey ----------------------------------------------
|
||||
|
||||
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x01}
|
||||
wgKey := wgtypes.Key{0xAA}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||
|
||||
psk := rp.Key{0xBB}
|
||||
h.HandshakeCompleted(pid, psk)
|
||||
|
||||
require.Len(t, iface.calls, 1)
|
||||
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
|
||||
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x02}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
|
||||
|
||||
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
|
||||
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
|
||||
|
||||
require.Len(t, iface.calls, 2)
|
||||
require.False(t, iface.calls[0].updateOnly)
|
||||
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
// no SetInterface — iface remains nil
|
||||
pid := rp.PeerID{0x03}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
|
||||
|
||||
// Must not panic.
|
||||
h.HandshakeCompleted(pid, rp.Key{})
|
||||
}
|
||||
|
||||
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
|
||||
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
|
||||
}
|
||||
|
||||
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface)
|
||||
|
||||
pid := rp.PeerID{0x04}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
|
||||
h.HandshakeCompleted(pid, rp.Key{0x01})
|
||||
require.True(t, h.IsPeerInitialized(pid))
|
||||
|
||||
h.RemovePeer(pid)
|
||||
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
|
||||
}
|
||||
|
||||
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
|
||||
h := NewNetbirdHandler()
|
||||
pid := rp.PeerID{0x05}
|
||||
wgKey := wgtypes.Key{0xEE}
|
||||
h.AddPeer(pid, "wt0", rp.Key(wgKey))
|
||||
|
||||
iface := &mockIface{}
|
||||
h.SetInterface(iface) // set after AddPeer
|
||||
|
||||
h.HandshakeCompleted(pid, rp.Key{0x42})
|
||||
require.Len(t, iface.calls, 1)
|
||||
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
|
||||
}
|
||||
|
||||
42
client/internal/rosenpass/seed.go
Normal file
42
client/internal/rosenpass/seed.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
|
||||
// of peer public keys. Both peers, given the same key pair, produce the same
|
||||
// output regardless of which side runs the function: the inputs are ordered
|
||||
// lexicographically before concatenation.
|
||||
//
|
||||
// NetBird uses this value as the initial Rosenpass-side preshared key when no
|
||||
// explicit account-level PSK is configured, so both peers converge on the same
|
||||
// PSK before the first post-quantum handshake completes.
|
||||
//
|
||||
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
|
||||
// from public keys and exists only to seed WireGuard until Rosenpass rotates
|
||||
// in a real post-quantum PSK.
|
||||
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
|
||||
lk := []byte(localKey)
|
||||
rk := []byte(remoteKey)
|
||||
if len(lk) < 16 || len(rk) < 16 {
|
||||
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
|
||||
}
|
||||
|
||||
var keyInput []byte
|
||||
if localKey > remoteKey {
|
||||
keyInput = append(keyInput, lk[:16]...)
|
||||
keyInput = append(keyInput, rk[:16]...)
|
||||
} else {
|
||||
keyInput = append(keyInput, rk[:16]...)
|
||||
keyInput = append(keyInput, lk[:16]...)
|
||||
}
|
||||
|
||||
key, err := wgtypes.NewKey(keyInput)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
|
||||
}
|
||||
return &key, nil
|
||||
}
|
||||
44
client/internal/rosenpass/seed_test.go
Normal file
44
client/internal/rosenpass/seed_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package rosenpass
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
|
||||
// Peer A and peer B must derive the same PSK regardless of which side
|
||||
// computes it: the function orders inputs internally.
|
||||
a := strings.Repeat("a", 32)
|
||||
b := strings.Repeat("b", 32)
|
||||
|
||||
keyAB, err := DeterministicSeedKey(a, b)
|
||||
require.NoError(t, err)
|
||||
keyBA, err := DeterministicSeedKey(b, a)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
|
||||
}
|
||||
|
||||
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
|
||||
a := strings.Repeat("a", 32)
|
||||
b := strings.Repeat("b", 32)
|
||||
c := strings.Repeat("c", 32)
|
||||
|
||||
keyAB, err := DeterministicSeedKey(a, b)
|
||||
require.NoError(t, err)
|
||||
keyAC, err := DeterministicSeedKey(a, c)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
|
||||
}
|
||||
|
||||
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
|
||||
short := "short" // < 16 bytes
|
||||
long := strings.Repeat("x", 32)
|
||||
|
||||
_, err := DeterministicSeedKey(short, long)
|
||||
require.Error(t, err)
|
||||
_, err = DeterministicSeedKey(long, short)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user