mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-01 13:39:56 +00:00
Compare commits
37 Commits
fix/wiregu
...
peer-acl-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
91ae600670 | ||
|
|
9189625487 | ||
|
|
e9dbf9db6f | ||
|
|
5a9e9e7bc9 | ||
|
|
43e041cf9f | ||
|
|
77e5693200 | ||
|
|
174dc24867 | ||
|
|
7ea5e37dd4 | ||
|
|
9d7ef9b255 | ||
|
|
944a258459 | ||
|
|
1f9a829f2c | ||
|
|
14af179556 | ||
|
|
1fbb5e6d5d | ||
|
|
6771e35d57 | ||
|
|
e89b1e0596 | ||
|
|
d542c60e21 | ||
|
|
4983b5cf17 | ||
|
|
b3b0feb3b8 | ||
|
|
7aebdd69dd | ||
|
|
0358be2313 | ||
|
|
37052fd5bc | ||
|
|
454ff66518 | ||
|
|
6137a1fcc5 | ||
|
|
4955c345d5 | ||
|
|
9192b4f029 | ||
|
|
c784b02550 | ||
|
|
d250f92c43 | ||
|
|
80966ab1b0 | ||
|
|
af24fd7796 | ||
|
|
13d32d274f | ||
|
|
705f87fc20 | ||
|
|
3f91f49277 | ||
|
|
347c5bf317 | ||
|
|
22e2519d71 | ||
|
|
e916f12cca | ||
|
|
9ed2e2a5b4 | ||
|
|
2ccae7ec47 |
45
.github/dependabot.yml
vendored
Normal file
45
.github/dependabot.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
actions:
|
||||
patterns:
|
||||
- "*"
|
||||
ignore:
|
||||
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
|
||||
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
|
||||
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
|
||||
- dependency-name: "git-town/action"
|
||||
update-types:
|
||||
- "version-update:semver-minor"
|
||||
- "version-update:semver-major"
|
||||
|
||||
- package-ecosystem: "gomod"
|
||||
directories:
|
||||
- "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 15
|
||||
groups:
|
||||
aws-sdk:
|
||||
patterns:
|
||||
- "github.com/aws/aws-sdk-go-v2/*"
|
||||
pion:
|
||||
patterns:
|
||||
- "github.com/pion/*"
|
||||
gorm:
|
||||
patterns:
|
||||
- "gorm.io/*"
|
||||
otel:
|
||||
patterns:
|
||||
- "go.opentelemetry.io/*"
|
||||
testcontainers:
|
||||
patterns:
|
||||
- "github.com/testcontainers/testcontainers-go/*"
|
||||
wireguard:
|
||||
patterns:
|
||||
- "golang.zx2c4.com/wireguard*"
|
||||
1
.github/pull_request_template.md
vendored
1
.github/pull_request_template.md
vendored
@@ -12,6 +12,7 @@
|
||||
- [ ] Is a feature enhancement
|
||||
- [ ] It is a refactor
|
||||
- [ ] Created tests that fail without the change (if possible)
|
||||
- [ ] This change does **not** modify the public API, gRPC protocols, functionality behavior, CLI / service flags, or introduce a new feature — **OR** I have discussed it with the NetBird team beforehand (link the issue / Slack thread in the description). See [CONTRIBUTING.md](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTING.md#discuss-changes-with-the-netbird-team-first).
|
||||
|
||||
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||
|
||||
|
||||
109
.github/workflows/check-license-dependencies.yml
vendored
109
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
pull_request:
|
||||
paths:
|
||||
- 'go.mod'
|
||||
- 'go.sum'
|
||||
- '.github/workflows/check-license-dependencies.yml'
|
||||
- "go.mod"
|
||||
- "go.sum"
|
||||
- ".github/workflows/check-license-dependencies.yml"
|
||||
|
||||
jobs:
|
||||
check-internal-dependencies:
|
||||
@@ -19,7 +19,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Check for problematic license dependencies
|
||||
run: |
|
||||
@@ -56,55 +59,57 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: 'go.mod'
|
||||
cache: true
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: true
|
||||
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
- name: Install go-licenses
|
||||
run: go install github.com/google/go-licenses@v1.6.0
|
||||
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
- name: Check for GPL/AGPL licensed dependencies
|
||||
run: |
|
||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||
echo ""
|
||||
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||
echo "Found copyleft licensed dependencies:"
|
||||
echo "$COPYLEFT_DEPS"
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||
INCOMPATIBLE=""
|
||||
while IFS=',' read -r package url license; do
|
||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||
# Find ALL packages that import this GPL package using go list
|
||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||
|
||||
# Check if any importer is NOT in management/signal/relay
|
||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||
|
||||
if [ -n "$BSD_IMPORTER" ]; then
|
||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||
else
|
||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||
fi
|
||||
fi
|
||||
done <<< "$COPYLEFT_DEPS"
|
||||
|
||||
if [ -n "$INCOMPATIBLE" ]; then
|
||||
echo ""
|
||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||
echo -e "$INCOMPATIBLE"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||
|
||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
||||
|
||||
- name: Verify docs PR exists (and is open or merged)
|
||||
if: steps.validate.outputs.mode == 'added'
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
id: verify
|
||||
with:
|
||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||
|
||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,11 +8,10 @@ jobs:
|
||||
post:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: roots/discourse-topic-github-release-action@main
|
||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
||||
with:
|
||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||
discourse-base-url: https://forum.netbird.io
|
||||
discourse-author-username: NetBird
|
||||
discourse-category: 17
|
||||
discourse-tags:
|
||||
releases
|
||||
discourse-tags: releases
|
||||
|
||||
8
.github/workflows/git-town.yml
vendored
8
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- '**'
|
||||
- "**"
|
||||
|
||||
jobs:
|
||||
git-town:
|
||||
@@ -15,7 +15,9 @@ jobs:
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: git-town/action@v1.2.1
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
||||
with:
|
||||
skip-single-stacks: true
|
||||
|
||||
9
.github/workflows/golang-test-darwin.yml
vendored
9
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,16 +16,18 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||
@@ -44,4 +46,3 @@ jobs:
|
||||
|
||||
- name: Test
|
||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
|
||||
21
.github/workflows/golang-test-freebsd.yml
vendored
21
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,20 +15,31 @@ jobs:
|
||||
name: "Client / Unit"
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test in FreeBSD
|
||||
id: test
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "14.2"
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
pkg install -y curl pkgconf xorg
|
||||
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -vLO "$GO_URL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||
|
||||
# -x - to print all executed commands
|
||||
# -e - to faile on first error
|
||||
|
||||
142
.github/workflows/golang-test-linux.yml
vendored
142
.github/workflows/golang-test-linux.yml
vendored
@@ -18,9 +18,11 @@ jobs:
|
||||
management: ${{ steps.filter.outputs.management }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- uses: dorny/paths-filter@v3
|
||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
||||
id: filter
|
||||
with:
|
||||
filters: |
|
||||
@@ -28,7 +30,7 @@ jobs:
|
||||
- 'management/**'
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -36,10 +38,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache
|
||||
with:
|
||||
path: |
|
||||
@@ -113,14 +115,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -128,10 +132,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -154,18 +158,20 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
|
||||
test_client_on_docker:
|
||||
name: "Client (Docker) / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -177,7 +183,7 @@ jobs:
|
||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
id: cache-restore
|
||||
with:
|
||||
path: |
|
||||
@@ -214,7 +220,7 @@ jobs:
|
||||
sh -c ' \
|
||||
apk update; apk add --no-cache \
|
||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||
'
|
||||
|
||||
test_relay:
|
||||
@@ -231,10 +237,12 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -246,10 +254,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -277,14 +285,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -298,7 +308,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -324,14 +334,16 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ '386','amd64' ]
|
||||
arch: ["386", "amd64"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -343,10 +355,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -370,19 +382,21 @@ jobs:
|
||||
|
||||
test_management:
|
||||
name: "Management / Unit"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -390,10 +404,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -410,7 +424,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -427,7 +441,7 @@ jobs:
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
run: |
|
||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||
CI=true \
|
||||
@@ -437,13 +451,13 @@ jobs:
|
||||
|
||||
benchmark:
|
||||
name: "Management / Benchmark"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -474,10 +488,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -485,10 +501,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -505,7 +521,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -529,13 +545,13 @@ jobs:
|
||||
|
||||
api_benchmark:
|
||||
name: "Management / Benchmark (API)"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres' ]
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Create Docker network
|
||||
@@ -566,10 +582,12 @@ jobs:
|
||||
prom/prometheus
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -577,10 +595,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -597,7 +615,7 @@ jobs:
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
@@ -623,20 +641,22 @@ jobs:
|
||||
|
||||
api_integration_test:
|
||||
name: "Management / Integration"
|
||||
needs: [ build-cache ]
|
||||
needs: [build-cache]
|
||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
arch: [ 'amd64' ]
|
||||
store: [ 'sqlite', 'postgres']
|
||||
arch: ["amd64"]
|
||||
store: ["sqlite", "postgres"]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -644,10 +664,10 @@ jobs:
|
||||
- name: Get Go environment
|
||||
run: |
|
||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache/restore@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
|
||||
19
.github/workflows/golang-test-windows.yml
vendored
19
.github/workflows/golang-test-windows.yml
vendored
@@ -18,10 +18,12 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
id: go
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
@@ -33,7 +35,7 @@ jobs:
|
||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
${{ env.cache }}
|
||||
@@ -44,16 +46,15 @@ jobs:
|
||||
${{ runner.os }}-go-
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
destination: ${{ env.downloadPath }}\wintun.zip
|
||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
||||
|
||||
- name: Decompressing wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
|
||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||
|
||||
|
||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -15,9 +15,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
@@ -38,13 +40,15 @@ jobs:
|
||||
timeout-minutes: 15
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Check for duplicate constants
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: |
|
||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
@@ -52,7 +56,7 @@ jobs:
|
||||
if: matrix.os == 'ubuntu-latest'
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
skip-cache: true
|
||||
|
||||
4
.github/workflows/install-script-test.yml
vendored
4
.github/workflows/install-script-test.yml
vendored
@@ -22,7 +22,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run install script
|
||||
env:
|
||||
|
||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,23 +16,25 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Setup Android SDK
|
||||
uses: android-actions/setup-android@v3
|
||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
||||
with:
|
||||
cmdline-tools-version: 8512546
|
||||
- name: Setup Java
|
||||
uses: actions/setup-java@v4
|
||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
||||
with:
|
||||
java-version: "11"
|
||||
distribution: "adopt"
|
||||
- name: NDK Cache
|
||||
id: ndk-cache
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: /usr/local/lib/android/sdk/ndk
|
||||
key: ndk-cache-23.1.7779620
|
||||
@@ -52,9 +54,11 @@ jobs:
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: install gomobile
|
||||
|
||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
|
||||
68
.github/workflows/proto-version-check.yml
vendored
68
.github/workflows/proto-version-check.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check for proto tool version changes
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
with:
|
||||
script: |
|
||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||
@@ -20,34 +20,66 @@ jobs:
|
||||
per_page: 100,
|
||||
});
|
||||
|
||||
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
|
||||
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
|
||||
if (missingPatch.length > 0) {
|
||||
core.setFailed(
|
||||
`Cannot inspect patch data for:\n` +
|
||||
missingPatch.map(f => `- ${f}`).join('\n') +
|
||||
`\nThis can happen with very large PRs. Verify proto versions manually.`
|
||||
);
|
||||
const modifiedPbFiles = files.filter(
|
||||
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||
);
|
||||
if (modifiedPbFiles.length === 0) {
|
||||
console.log('No modified .pb.go files to check');
|
||||
return;
|
||||
}
|
||||
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const violations = [];
|
||||
|
||||
for (const file of pbFiles) {
|
||||
const changed = file.patch
|
||||
.split('\n')
|
||||
.filter(line => versionPattern.test(line));
|
||||
if (changed.length > 0) {
|
||||
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||
const baseSha = context.payload.pull_request.base.sha;
|
||||
const headSha = context.payload.pull_request.head.sha;
|
||||
|
||||
async function getVersionHeader(path, ref) {
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
path,
|
||||
ref,
|
||||
});
|
||||
if (!res.data.content) {
|
||||
return { ok: false, reason: 'no inline content (file too large)' };
|
||||
}
|
||||
const content = Buffer.from(res.data.content, 'base64').toString('utf8');
|
||||
const lines = content
|
||||
.split('\n')
|
||||
.slice(0, 20)
|
||||
.filter(line => versionPattern.test(line));
|
||||
return { ok: true, lines };
|
||||
} catch (e) {
|
||||
return { ok: false, reason: e.message };
|
||||
}
|
||||
}
|
||||
|
||||
const violations = [];
|
||||
for (const file of modifiedPbFiles) {
|
||||
const [base, head] = await Promise.all([
|
||||
getVersionHeader(file.filename, baseSha),
|
||||
getVersionHeader(file.filename, headSha),
|
||||
]);
|
||||
if (!base.ok || !head.ok) {
|
||||
core.warning(
|
||||
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||
);
|
||||
continue;
|
||||
}
|
||||
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||
violations.push({
|
||||
file: file.filename,
|
||||
lines: changed,
|
||||
base: base.lines,
|
||||
head: head.lines,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (violations.length > 0) {
|
||||
const details = violations.map(v =>
|
||||
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
|
||||
`${v.file}:\n` +
|
||||
` base:\n${v.base.map(l => ' ' + l).join('\n') || ' (none)'}\n` +
|
||||
` head:\n${v.head.map(l => ' ' + l).join('\n') || ' (none)'}`
|
||||
).join('\n\n');
|
||||
|
||||
core.setFailed(
|
||||
|
||||
168
.github/workflows/release.yml
vendored
168
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
||||
pull_request:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.4"
|
||||
SIGN_PIPE_VER: "v0.1.5"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
@@ -24,7 +24,9 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Generate FreeBSD port diff
|
||||
run: bash release_files/freebsd-port-diff.sh
|
||||
@@ -51,19 +53,26 @@ jobs:
|
||||
echo "Generated files for version: $VERSION"
|
||||
cat netbird-*.diff
|
||||
|
||||
- name: Read Go version from go.mod
|
||||
id: goversion
|
||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Test FreeBSD port
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: vmactions/freebsd-vm@v1
|
||||
env:
|
||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
||||
with:
|
||||
usesh: true
|
||||
copyback: false
|
||||
release: "15.0"
|
||||
envs: "GO_VERSION"
|
||||
prepare: |
|
||||
# Install required packages
|
||||
pkg install -y git curl portlint go
|
||||
pkg install -y git curl portlint
|
||||
|
||||
# Install Go for building
|
||||
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||
curl -LO "$GO_URL"
|
||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||
@@ -93,19 +102,19 @@ jobs:
|
||||
|
||||
# Show patched Makefile
|
||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||
|
||||
|
||||
cd /usr/ports/security/netbird
|
||||
export BATCH=yes
|
||||
make package
|
||||
pkg add ./work/pkg/netbird-*.pkg
|
||||
|
||||
|
||||
netbird version | grep "$version"
|
||||
|
||||
echo "FreeBSD port test completed successfully!"
|
||||
|
||||
- name: Upload FreeBSD port files
|
||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: freebsd-port-files
|
||||
path: |
|
||||
@@ -124,26 +133,25 @@ jobs:
|
||||
env:
|
||||
flags: ""
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -156,18 +164,18 @@ jobs:
|
||||
- name: check git status
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v2
|
||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
||||
- name: Login to Docker hub
|
||||
if: github.event_name != 'pull_request'
|
||||
uses: docker/login-action@v1
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Log in to the GitHub container registry
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
@@ -191,7 +199,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --clean ${{ env.flags }}
|
||||
@@ -282,28 +290,28 @@ jobs:
|
||||
} >> "$GITHUB_OUTPUT"
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release
|
||||
path: dist/
|
||||
retention-days: 7
|
||||
- name: upload linux packages
|
||||
id: upload_linux_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: linux-packages
|
||||
path: dist/netbird_linux**
|
||||
retention-days: 7
|
||||
- name: upload windows packages
|
||||
id: upload_windows_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-packages
|
||||
path: dist/netbird_windows**
|
||||
retention-days: 7
|
||||
- name: upload macos packages
|
||||
id: upload_macos_packages
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: macos-packages
|
||||
path: dist/netbird_darwin**
|
||||
@@ -314,27 +322,26 @@ jobs:
|
||||
outputs:
|
||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -375,7 +382,7 @@ jobs:
|
||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||
|
||||
- name: Run GoReleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||
@@ -404,7 +411,7 @@ jobs:
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: dist/
|
||||
@@ -418,16 +425,17 @@ jobs:
|
||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||
persist-credentials: false
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
cache: false
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
@@ -441,7 +449,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
- name: Run GoReleaser
|
||||
id: goreleaser
|
||||
uses: goreleaser/goreleaser-action@v4
|
||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
||||
with:
|
||||
version: ${{ env.GORELEASER_VER }}
|
||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||
@@ -449,7 +457,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: upload non tags for debug purposes
|
||||
id: upload_release_ui_darwin
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: release-ui-darwin
|
||||
path: dist/
|
||||
@@ -474,27 +482,26 @@ jobs:
|
||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||
downloadPath: '${{ github.workspace }}\temp'
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Parse semver string
|
||||
id: semver_parser
|
||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||
with:
|
||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||
version_extractor_regex: '\/v(.*)$'
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
|
||||
- name: Add 7-Zip to PATH
|
||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
- name: Download release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release
|
||||
path: release
|
||||
|
||||
- name: Download UI release artifacts
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
||||
with:
|
||||
name: release-ui
|
||||
path: release-ui
|
||||
@@ -514,29 +521,27 @@ jobs:
|
||||
Get-ChildItem $workdir
|
||||
|
||||
- name: Download wintun
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-wintun
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
file-name: wintun.zip
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||
destination: ${{ env.downloadPath }}\wintun.zip
|
||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
||||
|
||||
- name: Decompress wintun files
|
||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
|
||||
|
||||
- name: Move wintun.dll into dist
|
||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download Mesa3D (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
id: download-mesa3d
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||
file-name: mesa3d.7z
|
||||
location: ${{ env.downloadPath }}
|
||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
|
||||
destination: ${{ env.downloadPath }}\mesa3d.7z
|
||||
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
|
||||
|
||||
- name: Extract Mesa3D driver (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
@@ -547,35 +552,38 @@ jobs:
|
||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||
|
||||
- name: Download EnVar plugin for NSIS
|
||||
uses: carlosperate/download-file-action@v2
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||
file-name: envar_plugin.zip
|
||||
location: ${{ github.workspace }}
|
||||
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
|
||||
destination: ${{ github.workspace }}\envar_plugin.zip
|
||||
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
|
||||
|
||||
- name: Extract EnVar plugin
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||
|
||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||
uses: carlosperate/download-file-action@v2
|
||||
if: matrix.arch == 'amd64'
|
||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
||||
with:
|
||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||
location: ${{ github.workspace }}
|
||||
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
|
||||
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
|
||||
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
|
||||
|
||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||
if: matrix.arch == 'amd64'
|
||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||
|
||||
- name: Build NSIS installer
|
||||
uses: joncloud/makensis-action@v3.3
|
||||
with:
|
||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||
script-file: client/installer.nsis
|
||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||
shell: pwsh
|
||||
env:
|
||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||
run: |
|
||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
||||
Copy-Item -Destination $nsisPluginDir -Force
|
||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
||||
|
||||
- name: Rename NSIS installer
|
||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||
@@ -592,7 +600,7 @@ jobs:
|
||||
|
||||
- name: Upload installer artifacts
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
||||
with:
|
||||
name: windows-installer-test-${{ matrix.arch }}
|
||||
path: |
|
||||
@@ -611,7 +619,7 @@ jobs:
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Create or update PR comment
|
||||
uses: actions/github-script@v7
|
||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
||||
env:
|
||||
RELEASE_RESULT: ${{ needs.release.result }}
|
||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||
@@ -703,7 +711,7 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger binaries sign pipelines
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: Sign bin and installer
|
||||
repo: netbirdio/sign-pipelines
|
||||
|
||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger main branch sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-main.yml
|
||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||
|
||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Trigger release tag sync
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: sync-tag.yml
|
||||
ref: main
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger android-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
@@ -42,10 +42,10 @@ jobs:
|
||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||
steps:
|
||||
- name: Trigger ios-client submodule bump
|
||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: bump-netbird.yml
|
||||
ref: main
|
||||
repo: netbirdio/ios-client
|
||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||
|
||||
26
.github/workflows/test-infrastructure-files.yml
vendored
26
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
||||
- main
|
||||
pull_request:
|
||||
paths:
|
||||
- 'infrastructure_files/**'
|
||||
- '.github/workflows/test-infrastructure-files.yml'
|
||||
- 'management/cmd/**'
|
||||
- 'signal/cmd/**'
|
||||
- "infrastructure_files/**"
|
||||
- ".github/workflows/test-infrastructure-files.yml"
|
||||
- "management/cmd/**"
|
||||
- "signal/cmd/**"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||
@@ -20,7 +20,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||
store: ["sqlite", "postgres", "mysql"]
|
||||
services:
|
||||
postgres:
|
||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||
@@ -68,15 +68,17 @@ jobs:
|
||||
run: sudo apt-get install -y curl
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
||||
with:
|
||||
path: ~/go/pkg/mod
|
||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||
@@ -139,8 +141,8 @@ jobs:
|
||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||
@@ -254,7 +256,9 @@ jobs:
|
||||
run: sudo apt-get install -y jq
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: run script with Zitadel PostgreSQL
|
||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||
|
||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
- "v*"
|
||||
paths:
|
||||
- 'shared/management/http/api/openapi.yml'
|
||||
- "shared/management/http/api/openapi.yml"
|
||||
|
||||
jobs:
|
||||
trigger_docs_api_update:
|
||||
@@ -13,10 +13,10 @@ jobs:
|
||||
if: startsWith(github.ref, 'refs/tags/')
|
||||
steps:
|
||||
- name: Trigger API pages generation
|
||||
uses: benc-uk/workflow-dispatch@v1
|
||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
||||
with:
|
||||
workflow: generate api pages
|
||||
repo: netbirdio/docs
|
||||
ref: "refs/heads/main"
|
||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||
|
||||
15
.github/workflows/wasm-build-validation.yml
vendored
15
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,15 +19,17 @@ jobs:
|
||||
GOARCH: wasm
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||
- name: Install golangci-lint
|
||||
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
||||
with:
|
||||
version: latest
|
||||
install-mode: binary
|
||||
@@ -42,9 +44,11 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
persist-credentials: false
|
||||
- name: Install Go
|
||||
uses: actions/setup-go@v5
|
||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
||||
with:
|
||||
go-version-file: "go.mod"
|
||||
- name: Build Wasm client
|
||||
@@ -65,4 +69,3 @@ jobs:
|
||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ If you haven't already, join our slack workspace [here](https://docs.netbird.io/
|
||||
- [Contributing to NetBird](#contributing-to-netbird)
|
||||
- [Contents](#contents)
|
||||
- [Code of conduct](#code-of-conduct)
|
||||
- [Discuss changes with the NetBird team first](#discuss-changes-with-the-netbird-team-first)
|
||||
- [Directory structure](#directory-structure)
|
||||
- [Development setup](#development-setup)
|
||||
- [Requirements](#requirements)
|
||||
@@ -33,6 +34,14 @@ Conduct which can be found in the file [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
|
||||
By participating, you are expected to uphold this code. Please report
|
||||
unacceptable behavior to community@netbird.io.
|
||||
|
||||
## Discuss changes with the NetBird team first
|
||||
|
||||
Changes to the **public API**, **gRPC protocols**, **functionality behavior**, **CLI / service flags**, or **new features** should be discussed with the NetBird team before you start the work. These surfaces are part of NetBird's contract with operators, self-hosters, and downstream integrators, and changes to them have compatibility, security, and release-planning implications that benefit from an early conversation.
|
||||
|
||||
Open an issue or reach out on [Slack](https://docs.netbird.io/slack-url) to talk through what you have in mind. We'll help shape the change, flag any constraints we know about, and confirm the direction so the PR review can focus on implementation rather than design.
|
||||
|
||||
Typical bug fixes, internal refactors, documentation updates, and tests do not need pre-discussion — open the PR directly.
|
||||
|
||||
## Directory structure
|
||||
|
||||
The NetBird project monorepo is organized to maintain most of its individual dependencies code within their directories, except for a few auxiliary or shared packages.
|
||||
|
||||
153
README.md
153
README.md
@@ -1,147 +1,134 @@
|
||||
|
||||
<div align="center">
|
||||
<br/>
|
||||
<br/>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png"/>
|
||||
</p>
|
||||
<p>
|
||||
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||
</a>
|
||||
<br>
|
||||
<p align="center">
|
||||
<img width="234" src="docs/media/logo-full.png" alt="NetBird logo"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://sonarcloud.io/dashboard?id=netbirdio_netbird">
|
||||
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" alt="SonarCloud alert status"/>
|
||||
</a>
|
||||
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
|
||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" alt="BSD-3 License"/>
|
||||
</a>
|
||||
<a href="https://docs.netbird.io/slack-url">
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack" alt="NetBird Slack"/>
|
||||
</a>
|
||||
<a href="https://forum.netbird.io">
|
||||
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||
</a>
|
||||
<br>
|
||||
<img src="https://img.shields.io/badge/community%20forum-@netbird-red.svg?logo=discourse" alt="Community forum"/>
|
||||
</a>
|
||||
<a href="https://gurubase.io/g/netbird">
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||
</a>
|
||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF" alt="Gurubase: Ask NetBird Guru"/>
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
|
||||
<p align="center">
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<strong>
|
||||
Start using NetBird at <a href="https://netbird.io/pricing">netbird.io</a>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
</strong>
|
||||
<br/>
|
||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||
<br/>
|
||||
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||
<br/>
|
||||
|
||||
</strong>
|
||||
<br>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
<br>
|
||||
<br>
|
||||
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||
New: NetBird terraform provider
|
||||
</a>
|
||||
<strong>
|
||||
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
|
||||
</strong>
|
||||
</p>
|
||||
|
||||
<br>
|
||||
|
||||
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
|
||||
|
||||
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.
|
||||
|
||||
**Secure.** NetBird enables secure remote access by applying granular access policies while allowing you to manage them intuitively from a single place. Works universally on any infrastructure.
|
||||
|
||||
### Open Source Network Security in a Single Platform
|
||||
|
||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||
|
||||
### Self-Host NetBird (Video)
|
||||
### Self-host NetBird (video)
|
||||
|
||||
[](https://youtu.be/bZAgpT6nzaQ)
|
||||
|
||||
### Key features
|
||||
|
||||
| Connectivity | Management | Security | Automation| Platforms |
|
||||
|----|----|----|----|----|
|
||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||
| Connectivity | Management | Security | Automation | Platforms |
|
||||
|---|---|---|---|---|
|
||||
| ✓ [Kernel WireGuard](https://docs.netbird.io/about-netbird/why-wireguard-with-netbird) | ✓ [Admin Web UI](https://github.com/netbirdio/dashboard) | ✓ [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) | ✓ [Public API](https://docs.netbird.io/api) | ✓ [Linux](https://docs.netbird.io/get-started/install/linux) |
|
||||
| ✓ [Peer-to-peer connections](https://docs.netbird.io/about-netbird/how-netbird-works) | ✓ Auto peer discovery and configuration | ✓ [Access control: groups & rules](https://docs.netbird.io/how-to/manage-network-access) | ✓ [Setup keys for bulk provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) | ✓ [macOS](https://docs.netbird.io/get-started/install/macos) |
|
||||
| ✓ Connection relay fallback | ✓ [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) | ✓ [Activity logging](https://docs.netbird.io/how-to/audit-events-logging) | ✓ [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) | ✓ [Windows](https://docs.netbird.io/get-started/install/windows) |
|
||||
| ✓ [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) | ✓ [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) | ✓ [Traffic events](https://docs.netbird.io/manage/activity/traffic-events-logging) | ✓ [IdP groups sync with JWT](https://docs.netbird.io/manage/team/idp-sync) | ✓ [Android](https://docs.netbird.io/get-started/install/android) |
|
||||
| ✓ [Domain-based DNS routes](https://docs.netbird.io/manage/dns/dns-aliases-for-routed-networks) | ✓ [Custom DNS zones](https://docs.netbird.io/manage/dns/custom-zones) | ✓ [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) | ✓ [Terraform provider](https://registry.terraform.io/providers/netbirdio/netbird/latest) | ✓ [Android TV](https://docs.netbird.io/get-started/install/android-tv) |
|
||||
| ✓ [Exit nodes](https://docs.netbird.io/manage/network-routes/use-cases/exit-nodes) | ✓ [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) | ✓ Peer-to-peer encryption | ✓ [Ansible collection](https://github.com/netbirdio/ansible-netbird) | ✓ [iOS](https://docs.netbird.io/get-started/install/ios) |
|
||||
| ✓ [IPv6 dual-stack overlay](https://docs.netbird.io/manage/settings/ipv6) | ✓ [Multi-account profile switching](https://docs.netbird.io/client/profiles) | ✓ [SSH with central access policies](https://docs.netbird.io/manage/peers/ssh) | | ✓ [Apple TV](https://docs.netbird.io/get-started/install/tvos) |
|
||||
| ✓ [Browser SSH & RDP](https://docs.netbird.io/manage/peers/browser-client) | | ✓ [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) | | ✓ FreeBSD |
|
||||
| ✓ [Reverse proxy with auto-TLS](https://docs.netbird.io/manage/reverse-proxy) | | ✓ [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication) | | ✓ [pfSense](https://docs.netbird.io/get-started/install/pfsense) |
|
||||
| | | | | ✓ [OPNsense](https://docs.netbird.io/get-started/install/opnsense) |
|
||||
| | | | | ✓ [MikroTik RouterOS](https://docs.netbird.io/use-cases/homelab/client-on-mikrotik-router) |
|
||||
| | | | | ✓ OpenWRT |
|
||||
| | | | | ✓ [Synology](https://docs.netbird.io/get-started/install/synology) |
|
||||
| | | | | ✓ [TrueNAS](https://docs.netbird.io/get-started/install/truenas) |
|
||||
| | | | | ✓ [Proxmox](https://docs.netbird.io/get-started/install/proxmox-ve) |
|
||||
| | | | | ✓ [Raspberry Pi](https://docs.netbird.io/get-started/install/raspberrypi) |
|
||||
| | | | | ✓ [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) |
|
||||
| | | | | ✓ [Container](https://docs.netbird.io/get-started/install/docker) |
|
||||
|
||||
### Quickstart with NetBird Cloud
|
||||
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)
|
||||
- Follow the steps to sign-up with Google, Microsoft, GitHub or your email address.
|
||||
- Check NetBird [admin UI](https://app.netbird.io/).
|
||||
- Add more machines.
|
||||
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install).
|
||||
- Follow the steps to sign up with Google, Microsoft, GitHub or your email address.
|
||||
- Check the NetBird [admin UI](https://app.netbird.io/).
|
||||
|
||||
### Quickstart with self-hosted NetBird
|
||||
|
||||
> This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM.
|
||||
Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IDPs.
|
||||
This is the quickest way to try self-hosted NetBird. It should take around 5 minutes to get started if you already have a public domain and a VM. Follow the [Advanced guide with a custom identity provider](https://docs.netbird.io/selfhosted/selfhosted-guide#advanced-guide-with-a-custom-identity-provider) for installations with different IdPs.
|
||||
|
||||
**Infrastructure requirements:**
|
||||
- A Linux VM with at least **1CPU** and **2GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port: **3478**.
|
||||
- **Public domain** name pointing to the VM.
|
||||
- A Linux VM with at least **1 CPU** and **2 GB** of memory.
|
||||
- The VM should be publicly accessible on TCP ports **80** and **443** and UDP port **3478**.
|
||||
- A **public domain** name pointing to the VM.
|
||||
|
||||
**Software requirements:**
|
||||
- Docker installed on the VM with the docker-compose plugin ([Docker installation guide](https://docs.docker.com/engine/install/)) or docker with docker-compose in version 2 or higher.
|
||||
- [jq](https://jqlang.github.io/jq/) installed. In most distributions
|
||||
Usually available in the official repositories and can be installed with `sudo apt install jq` or `sudo yum install jq`
|
||||
- [curl](https://curl.se/) installed.
|
||||
Usually available in the official repositories and can be installed with `sudo apt install curl` or `sudo yum install curl`
|
||||
- Docker with the Compose plugin (Compose v2 or higher). See the [Docker installation guide](https://docs.docker.com/engine/install/).
|
||||
|
||||
**Steps**
|
||||
- Download and run the installation script:
|
||||
```bash
|
||||
export NETBIRD_DOMAIN=netbird.example.com; curl -fsSL https://github.com/netbirdio/netbird/releases/latest/download/getting-started.sh | bash
|
||||
```
|
||||
- Once finished, you can manage the resources via `docker-compose`
|
||||
|
||||
### A bit on NetBird internals
|
||||
- Every machine in the network runs [NetBird Agent (or Client)](client/) that manages WireGuard.
|
||||
- Every agent connects to [Management Service](management/) that holds network state, manages peer IPs, and distributes network updates to agents (peers).
|
||||
- NetBird agent uses WebRTC ICE implemented in [pion/ice library](https://github.com/pion/ice) to discover connection candidates when establishing a peer-to-peer connection between machines.
|
||||
- Connection candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through [Signal Service](signal/) passing p2p encrypted messages with candidates.
|
||||
- Sometimes the NAT traversal is unsuccessful due to strict NATs (e.g. mobile carrier-grade NAT) and a p2p connection isn't possible. When this occurs the system falls back to a relay server called [TURN](https://en.wikipedia.org/wiki/Traversal_Using_Relays_around_NAT), and a secure WireGuard tunnel is established via the TURN server.
|
||||
|
||||
[Coturn](https://github.com/coturn/coturn) is the one that has been successfully used for STUN and TURN in NetBird setups.
|
||||
- Every machine in the network runs the [NetBird agent](client/), which manages WireGuard.
|
||||
- Every agent connects to the [Management Service](management/), which holds network state, manages peer IPs, and distributes updates to agents.
|
||||
- Agents use ICE (via [pion/ice](https://github.com/pion/ice)) to discover connection candidates for peer-to-peer connections.
|
||||
- Candidates are discovered with the help of [STUN](https://en.wikipedia.org/wiki/STUN) servers.
|
||||
- Agents negotiate a connection through the [Signal Service](signal/), exchanging end-to-end encrypted messages with candidates.
|
||||
- When NAT traversal fails (e.g. mobile carrier-grade NAT) and a direct p2p connection isn't possible, the system falls back to a [Relay Service](relay/) and a secure WireGuard tunnel is established through it.
|
||||
|
||||
<p float="left" align="middle">
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700"/>
|
||||
<img src="https://docs.netbird.io/docs-static/img/about-netbird/high-level-dia.png" width="700" alt="NetBird high-level architecture diagram"/>
|
||||
</p>
|
||||
|
||||
See a complete [architecture overview](https://docs.netbird.io/about-netbird/how-netbird-works#architecture) for details.
|
||||
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) - terminal UI for managing NetBird peers, routes, and settings
|
||||
- [caddy-netbird](https://github.com/lixmal/caddy-netbird) - Caddy plugin that embeds a NetBird client for proxying HTTP and TCP/UDP traffic through NetBird networks
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
### Support acknowledgement
|
||||
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by The Federal Ministry of Education and Research of The Federal Republic of Germany. Together with [CISPA Helmholtz Center for Information Security](https://cispa.de/en) NetBird brings the security best practices and simplicity to private networking.
|
||||
In November 2022, NetBird joined the [StartUpSecure program](https://www.forschung-it-sicherheit-kommunikationssysteme.de/foerderung/bekanntmachungen/startup-secure) sponsored by the Federal Ministry of Education and Research of the Federal Republic of Germany. Together with the [CISPA Helmholtz Center for Information Security](https://cispa.de/en), NetBird brings security best practices and simplicity to private networking.
|
||||
|
||||

|
||||
|
||||
### Testimonials
|
||||
We use open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE (WebRTC)](https://github.com/pion/ice), [Coturn](https://github.com/coturn/coturn), and [Rosenpass](https://rosenpass.eu). We very much appreciate the work these guys are doing and we'd greatly appreciate if you could support them in any way (e.g., by giving a star or a contribution).
|
||||
### Acknowledgements
|
||||
We build on open-source technologies like [WireGuard®](https://www.wireguard.com/), [Pion ICE](https://github.com/pion/ice), and [Rosenpass](https://rosenpass.eu). We greatly appreciate the work these projects are doing, and we'd love it if you could support them too (e.g., by starring or contributing).
|
||||
|
||||
### Legal
|
||||
This repository is licensed under BSD-3-Clause license that applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
This repository is licensed under the BSD-3-Clause license, which applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||
|
||||
_WireGuard_ and the _WireGuard_ logo are [registered trademarks](https://www.wireguard.com/trademark-policy/) of Jason A. Donenfeld.
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"go.opentelemetry.io/otel"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
nbcache "github.com/netbirdio/netbird/management/server/cache"
|
||||
|
||||
@@ -109,7 +109,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
iv, _ := validator.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
@@ -84,6 +85,12 @@ type Options struct {
|
||||
DisableIPv6 bool
|
||||
// BlockInbound blocks all inbound connections from peers
|
||||
BlockInbound bool
|
||||
// BlockLANAccess blocks the embedded peer from reaching the host's
|
||||
// LAN (RFC 1918, link-local, loopback) when it's used as a routing
|
||||
// peer. Mirrors profilemanager.ConfigInput.BlockLANAccess. Useful
|
||||
// when the embedded client must never act as a stepping stone into
|
||||
// the host's local network (e.g. the proxy's overlay peer).
|
||||
BlockLANAccess bool
|
||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the tunnel interface.
|
||||
@@ -94,6 +101,26 @@ type Options struct {
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
// Performance configures the tunnel's buffer pool cap and batch size.
|
||||
Performance Performance
|
||||
}
|
||||
|
||||
// Performance configures the embedded client's tunnel memory/throughput knobs.
|
||||
//
|
||||
// These settings are process-global: any non-nil field also becomes the
|
||||
// default for Clients constructed by later embed.New calls in the same
|
||||
// process. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
// PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero
|
||||
// leaves the pool unbounded. Lower values trade throughput for a
|
||||
// tighter memory ceiling. May also be changed on a running Client via
|
||||
// Client.SetPerformance, provided this field was nonzero at construction.
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
// MaxBatchSize overrides the number of packets the tunnel reads or
|
||||
// writes per syscall, which also bounds eager buffer allocation per
|
||||
// worker. Zero uses the platform default. Applied at construction
|
||||
// only; ignored by Client.SetPerformance.
|
||||
MaxBatchSize *uint32
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -175,6 +202,7 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
DisableIPv6: &opts.DisableIPv6,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
BlockLANAccess: &opts.BlockLANAccess,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
DNSLabels: parsedLabels,
|
||||
@@ -192,6 +220,13 @@ func New(opts Options) (*Client, error) {
|
||||
config.PrivateKey = opts.PrivateKey
|
||||
}
|
||||
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool)
|
||||
}
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
deviceName: opts.DeviceName,
|
||||
setupKey: opts.SetupKey,
|
||||
@@ -405,6 +440,21 @@ func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IdentityForIP looks up a remote peer by its tunnel IP using the
|
||||
// embedded client's status recorder. Returns the peer's WireGuard public
|
||||
// key and FQDN. ok=false means the IP isn't in this client's peer
|
||||
// roster — callers should treat that as "unknown peer".
|
||||
func (c *Client) IdentityForIP(ip netip.Addr) (pubKey, fqdn string, ok bool) {
|
||||
if !ip.IsValid() || c.recorder == nil {
|
||||
return "", "", false
|
||||
}
|
||||
state, found := c.recorder.PeerStateByIP(ip.String())
|
||||
if !found {
|
||||
return "", "", false
|
||||
}
|
||||
return state.PubKey, state.FQDN, true
|
||||
}
|
||||
|
||||
// Status returns the current status of the client.
|
||||
func (c *Client) Status() (peer.FullStatus, error) {
|
||||
c.mu.Lock()
|
||||
@@ -473,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool
|
||||
// takes effect, and only when it was nonzero at construction;
|
||||
// MaxBatchSize is construction-only and returns an error if set here.
|
||||
//
|
||||
// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not
|
||||
// running yet.
|
||||
func (c *Client) SetPerformance(t Performance) error {
|
||||
if t.MaxBatchSize != nil {
|
||||
return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime")
|
||||
}
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return engine.SetPerformance(internal.Performance{
|
||||
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
|
||||
})
|
||||
}
|
||||
|
||||
// StartCapture begins capturing packets on this client's tunnel device.
|
||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
||||
|
||||
@@ -1,554 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/uuid"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
tableName = "filter"
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||
// external DNAT from bypassing ACL rules.
|
||||
mangleFwdKey = "MANGLE-FORWARD"
|
||||
)
|
||||
|
||||
type aclEntries map[string][][]string
|
||||
|
||||
type entry struct {
|
||||
spec []string
|
||||
position int
|
||||
}
|
||||
|
||||
type aclManager struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
entries aclEntries
|
||||
optionalEntries map[string][]entry
|
||||
ipsetStore *ipsetStore
|
||||
v6 bool
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||
return &aclManager{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
entries: make(map[string][][]string),
|
||||
optionalEntries: make(map[string][]entry),
|
||||
ipsetStore: newIpsetStore(),
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||
m.stateManager = stateManager
|
||||
|
||||
m.seedInitialEntries()
|
||||
m.seedInitialOptionalEntries()
|
||||
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
if err := m.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
protocol firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
chain := chainNameInputRules
|
||||
|
||||
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||
if m.v6 && ipsetName != "" {
|
||||
ipsetName += "-v6"
|
||||
}
|
||||
proto := protoForFamily(protocol, m.v6)
|
||||
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||
|
||||
mangleSpecs := slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
if ipsetName != "" {
|
||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
// if ruleset already exists it means we already have the firewall rule
|
||||
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||
ipList.addIP(ip.String())
|
||||
return []firewall.Rule{&Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
specs: specs,
|
||||
v6: m.v6,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := m.createIPSet(ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("create ipset: %w", err)
|
||||
}
|
||||
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||
}
|
||||
|
||||
ipList := newIpList(ip.String())
|
||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||
}
|
||||
|
||||
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||
}
|
||||
if ok {
|
||||
return nil, fmt.Errorf("rule already exists")
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
if action == firewall.ActionDrop {
|
||||
// Insert at the beginning of the chain (position 1)
|
||||
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
|
||||
} else {
|
||||
err = m.iptablesClient.Append(tableFilter, chain, specs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
ruleID: uuid.New().String(),
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
ipsetName: ipsetName,
|
||||
ip: ip.String(),
|
||||
chain: chain,
|
||||
v6: m.v6,
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
shouldDestroyIpset := false
|
||||
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||
// delete IP from ruleset IPs list and ipset
|
||||
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||
ip := net.ParseIP(r.ip)
|
||||
if ip == nil {
|
||||
return fmt.Errorf("parse IP %s", r.ip)
|
||||
}
|
||||
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||
}
|
||||
delete(ipsetList.ips, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ipsetList.ips) != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||
shouldDestroyIpset = true
|
||||
}
|
||||
|
||||
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||
}
|
||||
|
||||
if r.mangleSpecs != nil {
|
||||
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDestroyIpset {
|
||||
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy empty ipset: %v", err)
|
||||
} else {
|
||||
log.Errorf("destroy empty ipset: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) Reset() error {
|
||||
if err := m.cleanChains(); err != nil {
|
||||
return fmt.Errorf("clean chains: %w", err)
|
||||
}
|
||||
|
||||
m.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo write less destructive cleanup mechanism
|
||||
func (m *aclManager) cleanChains() error {
|
||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["INPUT"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries["FORWARD"] {
|
||||
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range m.entries["PREROUTING"] {
|
||||
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||
if err := m.flushIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
} else {
|
||||
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||
}
|
||||
}
|
||||
m.ipsetStore.deleteIpset(ipsetName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) createDefaultChains() error {
|
||||
// chain netbird-acl-input-rules
|
||||
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chainName, rules := range m.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chainName == mangleFwdKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chainName, entries := range m.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(m.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range m.entries[mangleFwdKey] {
|
||||
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||
// We want to make sure our traffic is not dropped by existing rules.
|
||||
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
func (m *aclManager) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||
|
||||
// Inbound is handled by our ACLs, the rest is dropped.
|
||||
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||
|
||||
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||
// ACL mark check where it cannot be overridden.
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
m.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", m.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (m *aclManager) seedInitialOptionalEntries() {
|
||||
m.optionalEntries["FORWARD"] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||
}
|
||||
|
||||
func (m *aclManager) updateState() {
|
||||
if m.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if m.v6 {
|
||||
currentState.ACLEntries6 = m.entries
|
||||
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||
} else {
|
||||
currentState.ACLEntries = m.entries
|
||||
currentState.ACLIPsetStore = m.ipsetStore
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// filterRuleSpecs returns the specs of a filtering rule
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||
// don't use IP matching if IP is 0.0.0.0
|
||||
matchByIP := !ip.IsUnspecified()
|
||||
|
||||
if matchByIP {
|
||||
if ipsetName != "" {
|
||||
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
||||
} else {
|
||||
specs = append(specs, "-s", ip.String())
|
||||
}
|
||||
}
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
|
||||
if ipsetName == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
actionSuffix := ""
|
||||
if action == firewall.ActionDrop {
|
||||
actionSuffix = "-drop"
|
||||
}
|
||||
|
||||
switch {
|
||||
case sPort != nil && dPort != nil:
|
||||
return ipsetName + "-sport-dport" + actionSuffix
|
||||
case sPort != nil:
|
||||
return ipsetName + "-sport" + actionSuffix
|
||||
case dPort != nil:
|
||||
return ipsetName + "-dport" + actionSuffix
|
||||
default:
|
||||
return ipsetName + actionSuffix
|
||||
}
|
||||
}
|
||||
|
||||
func (m *aclManager) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if m.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||
cidr := uint8(32)
|
||||
if ip.To4() == nil {
|
||||
cidr = 128
|
||||
}
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: cidr,
|
||||
}
|
||||
|
||||
if err := ipset.Del(name, entry); err != nil {
|
||||
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *aclManager) flushIPSet(name string) error {
|
||||
return ipset.Flush(name)
|
||||
}
|
||||
|
||||
func (m *aclManager) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
346
client/firewall/iptables/chains_linux.go
Normal file
346
client/firewall/iptables/chains_linux.go
Normal file
@@ -0,0 +1,346 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) cleanUpDefaultForwardRules() error {
|
||||
if err := r.cleanJumpRules(); err != nil {
|
||||
return fmt.Errorf("clean jump rules: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("flushing routing related tables")
|
||||
|
||||
// Remove jump rules from built-in chains before deleting custom chains,
|
||||
// otherwise the chain deletion fails with "device or resource busy".
|
||||
if ok, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
} else if ok {
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Delete(tableNat, chainOUTPUT, jumpRule...); err != nil {
|
||||
log.Debugf("clean OUTPUT jump rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWDIN, tableFilter},
|
||||
{chainRTFWDOUT, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainNATOutput, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
} else if ok {
|
||||
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) createContainers() error {
|
||||
for _, chainInfo := range []struct {
|
||||
chain string
|
||||
table string
|
||||
}{
|
||||
{chainRTFWDIN, tableFilter},
|
||||
{chainRTFWDOUT, tableFilter},
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
// Fallback: clear chains that survived an unclean shutdown.
|
||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add static nat rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addJumpRules(); err != nil {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) setupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
preRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePre] = preRule
|
||||
}
|
||||
|
||||
postRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "NEW",
|
||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
r.rules[markManglePost] = postRule
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) cleanupDataPlaneMark() error {
|
||||
var merr *multierror.Error
|
||||
if preRule, exists := r.rules[markManglePre]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePre)
|
||||
}
|
||||
}
|
||||
|
||||
if postRule, exists := r.rules[markManglePost]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, markManglePost)
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) addJumpRules() error {
|
||||
// Jump to nat chain
|
||||
natRule := []string{"-j", chainRTNAT}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||
return fmt.Errorf("add nat postrouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatPost] = natRule
|
||||
|
||||
// Jump to mangle prerouting chain
|
||||
preRule := []string{"-j", chainRTPRE}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpManglePre] = preRule
|
||||
|
||||
// Jump to nat prerouting chain
|
||||
rdrRule := []string{"-j", chainRTRDR}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
|
||||
return fmt.Errorf("add nat prerouting jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatPre] = rdrRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) cleanJumpRules() error {
|
||||
for _, ruleKey := range []firewall.RuleID{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
var table, chain string
|
||||
switch ruleKey {
|
||||
case jumpNatPost:
|
||||
table = tableNat
|
||||
chain = chainPOSTROUTING
|
||||
case jumpManglePre:
|
||||
table = tableMangle
|
||||
chain = chainPREROUTING
|
||||
case jumpNatPre:
|
||||
table = tableNat
|
||||
chain = chainPREROUTING
|
||||
case jumpMSSClamp:
|
||||
table = tableMangle
|
||||
chain = chainFORWARD
|
||||
default:
|
||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||
return fmt.Errorf("delete rule from chain %s in table %s: %w", chain, table, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) cleanAclChains() error {
|
||||
ok, err := r.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to list chains: %s", err)
|
||||
return err
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range r.entries[chainINPUT] {
|
||||
err := r.iptablesClient.DeleteIfExists(tableName, chainINPUT, rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range r.entries[chainFORWARD] {
|
||||
err := r.iptablesClient.DeleteIfExists(tableName, chainFORWARD, rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = r.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||
if err != nil {
|
||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ok, err = r.iptablesClient.ChainExists(tableMangle, chainPREROUTING)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
if ok {
|
||||
for _, rule := range r.entries[chainPREROUTING] {
|
||||
err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, rule...)
|
||||
if err != nil {
|
||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, rule := range r.entries[mangleFwdKey] {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) createDefaultChains() error {
|
||||
if err := r.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||
return err
|
||||
}
|
||||
|
||||
for chain, rules := range r.entries {
|
||||
// mangle FORWARD guard rules are handled separately below
|
||||
if chain == mangleFwdKey {
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if err := r.iptablesClient.InsertUnique(tableName, string(chain), 1, rule...); err != nil {
|
||||
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for chain, entries := range r.optionalEntries {
|
||||
for _, entry := range entries {
|
||||
if err := r.iptablesClient.InsertUnique(tableName, string(chain), entry.position, entry.spec...); err != nil {
|
||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||
continue
|
||||
}
|
||||
r.entries[chain] = append(r.entries[chain], entry.spec)
|
||||
}
|
||||
}
|
||||
clear(r.optionalEntries)
|
||||
|
||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||
for _, rule := range r.entries[mangleFwdKey] {
|
||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// seedInitialEntries adds default rules to the entries map. Rules are
|
||||
// inserted at position 1, so the order here is reversed.
|
||||
//
|
||||
// Existing FORWARD policy decides outbound traffic towards our
|
||||
// interface. If FORWARD policy is "drop", we add an
|
||||
// established/related rule to allow return traffic for inbound rules.
|
||||
func (r *family) seedInitialEntries() {
|
||||
established := getConntrackEstablished()
|
||||
|
||||
r.appendToEntries(chainINPUT, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||
r.appendToEntries(chainINPUT, []string{"-i", r.wgIface.Name(), "-j", chainNameInputRules})
|
||||
r.appendToEntries(chainINPUT, append([]string{"-i", r.wgIface.Name()}, established...))
|
||||
|
||||
r.appendToEntries(chainFORWARD, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
||||
r.appendToEntries(chainFORWARD, []string{"-o", r.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||
r.appendToEntries(chainFORWARD, []string{"-i", r.wgIface.Name(), "-j", chainRTFWDIN})
|
||||
|
||||
// Mangle FORWARD guard: when external DNAT redirects traffic from
|
||||
// the wg interface, it traverses FORWARD instead of INPUT,
|
||||
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
|
||||
// inserted above ours. Mangle runs before filter, so these guard
|
||||
// rules enforce the ACL mark check where it cannot be overridden.
|
||||
r.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||
"-j", "ACCEPT",
|
||||
})
|
||||
r.appendToEntries(mangleFwdKey, []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "conntrack", "--ctstate", "DNAT",
|
||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
"-j", "DROP",
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) seedInitialOptionalEntries() {
|
||||
r.optionalEntries[chainFORWARD] = []entry{
|
||||
{
|
||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||
position: 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
|
||||
r.entries[chain] = append(r.entries[chain], spec)
|
||||
}
|
||||
269
client/firewall/iptables/dnat_linux.go
Normal file
269
client/firewall/iptables/dnat_linux.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
toDestination := rule.TranslatedAddress.String()
|
||||
switch {
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
// no translated port, use original port
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
// need the "/originalport" suffix to avoid dnat port randomization
|
||||
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
|
||||
proto := strings.ToLower(string(rule.Protocol))
|
||||
|
||||
rules := make(map[firewall.RuleID]ruleInfo, 3)
|
||||
|
||||
// DNAT rule
|
||||
dnatRule := []string{
|
||||
"!", "-i", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-j", "DNAT",
|
||||
"--to-destination", toDestination,
|
||||
}
|
||||
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||
rules[ruleKey+dnatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
// SNAT rule
|
||||
snatRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "MASQUERADE",
|
||||
}
|
||||
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleKey+snatSuffix] = ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTNAT,
|
||||
rule: snatRule,
|
||||
}
|
||||
|
||||
// Forward filtering rule, if fwd policy is DROP
|
||||
forwardRule := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", proto,
|
||||
"-d", rule.TranslatedAddress.String(),
|
||||
"-j", "ACCEPT",
|
||||
}
|
||||
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||
rules[ruleKey+fwdSuffix] = ruleInfo{
|
||||
table: tableFilter,
|
||||
chain: chainRTFWDOUT,
|
||||
rule: forwardRule,
|
||||
}
|
||||
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||
log.Errorf("rollback failed: %v", rollbackErr)
|
||||
}
|
||||
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||
}
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
|
||||
var merr *multierror.Error
|
||||
for key, ruleInfo := range rules {
|
||||
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||
// On rollback error, add to rules map for next cleanup
|
||||
r.rules[key] = ruleInfo.rule
|
||||
}
|
||||
}
|
||||
if merr != nil {
|
||||
r.updateState()
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
var merr *multierror.Error
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
}
|
||||
|
||||
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
}
|
||||
|
||||
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDOUT, fwdRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||
}
|
||||
delete(r.rules, ruleKey+fwdSuffix)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-i", r.wgIface.Name(),
|
||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-d", localAddr.String(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
}
|
||||
|
||||
info := ruleInfo{
|
||||
table: tableNat,
|
||||
chain: chainRTRDR,
|
||||
rule: dnatRule,
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = info.rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
||||
func (r *family) ensureNATOutputChain() error {
|
||||
if _, exists := r.rules[jumpNatOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
if !chainExists {
|
||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
||||
}
|
||||
}
|
||||
|
||||
jumpRule := []string{"-j", chainNATOutput}
|
||||
if err := r.iptablesClient.Insert(tableNat, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||
if !chainExists {
|
||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
||||
}
|
||||
r.rules[jumpNatOutput] = jumpRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnatRule := []string{
|
||||
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
||||
"--dport", strconv.Itoa(int(originalPort)),
|
||||
"-d", localAddr.String(),
|
||||
"-j", "DNAT",
|
||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
248
client/firewall/iptables/family_linux.go
Normal file
248
client/firewall/iptables/family_linux.go
Normal file
@@ -0,0 +1,248 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// constants needed to manage and create iptable rules
|
||||
const (
|
||||
tableFilter = "filter"
|
||||
tableName = tableFilter
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
|
||||
// chainNameInputRules is the peer ACL chain that holds installed
|
||||
// peer-filtering rules.
|
||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||
|
||||
// mangleFwdKey is the entries map key for mangle FORWARD guard
|
||||
// rules that prevent external DNAT from bypassing ACL rules.
|
||||
mangleFwdKey chainKey = "MANGLE-FORWARD"
|
||||
|
||||
chainINPUT = "INPUT"
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpNatOutput = "jump-nat-output"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
|
||||
dnatSuffix firewall.RuleID = "_dnat"
|
||||
snatSuffix firewall.RuleID = "_snat"
|
||||
fwdSuffix firewall.RuleID = "_fwd"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
chain string
|
||||
table string
|
||||
rule []string
|
||||
}
|
||||
|
||||
type routeRules map[firewall.RuleID][]string
|
||||
|
||||
// ruleSpec is a single iptables rule expressed as its argument list
|
||||
// (e.g. {"-i", "wg0", "-j", "DROP"}).
|
||||
type ruleSpec []string
|
||||
|
||||
// chainKey identifies the chain a seeded entry belongs to. It holds
|
||||
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
|
||||
// synthetic mangleFwdKey bucket for the mangle FORWARD guard rules.
|
||||
type chainKey string
|
||||
|
||||
// aclEntries maps a chain to the rules seeded into it to jump into or
|
||||
// guard the netbird ACL chains.
|
||||
type aclEntries map[chainKey][]ruleSpec
|
||||
|
||||
type entry struct {
|
||||
spec ruleSpec
|
||||
position int
|
||||
}
|
||||
|
||||
// ipsetCounter is the shared hash:net refcounter used by peer and
|
||||
// route ACLs alike. The ipset library does not support comments, so
|
||||
// the key is just the set name (string).
|
||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||
|
||||
// family holds the per-address-family iptables state. One instance
|
||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||
// single family; the top-level Manager owns one for v4 and another
|
||||
// for v6.
|
||||
type family struct {
|
||||
iptablesClient *iptables.IPTables
|
||||
wgIface iFaceMapper
|
||||
v6 bool
|
||||
|
||||
// Peer ACL chain bookkeeping.
|
||||
entries aclEntries
|
||||
optionalEntries map[chainKey][]entry
|
||||
|
||||
// filters holds peer + route filter rules keyed by content hash.
|
||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||
filters map[nbid.RuleID]*Rule
|
||||
ipsetCounter *ipsetCounter
|
||||
|
||||
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
|
||||
// plumbing that isn't a filter rule).
|
||||
rules routeRules
|
||||
|
||||
// Routing / NAT.
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
}
|
||||
|
||||
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||
r := &family{
|
||||
iptablesClient: iptablesClient,
|
||||
wgIface: wgIface,
|
||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||
entries: make(aclEntries),
|
||||
optionalEntries: make(map[chainKey][]entry),
|
||||
filters: make(map[nbid.RuleID]*Rule),
|
||||
rules: make(routeRules),
|
||||
mtu: mtu,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
||||
return struct{}{}, r.createIpSet(name, sources)
|
||||
},
|
||||
func(name string, _ struct{}) error {
|
||||
return r.deleteIpSet(name)
|
||||
},
|
||||
)
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
// init wires the family to the state manager and installs both the
|
||||
// route ACL containers and the peer ACL chain skeleton.
|
||||
func (r *family) init(stateManager *statemanager.Manager) error {
|
||||
r.stateManager = stateManager
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
r.seedInitialEntries()
|
||||
r.seedInitialOptionalEntries()
|
||||
|
||||
if err := r.cleanAclChains(); err != nil {
|
||||
return fmt.Errorf("clean acl chains: %w", err)
|
||||
}
|
||||
if err := r.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default chains: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset tears down all firewall state owned by this family. ACL
|
||||
// chain cleanup runs before route-chain cleanup because the route
|
||||
// chains are still referenced by FORWARD jumps installed during
|
||||
// seedInitialEntries; deleting them first would trip EBUSY.
|
||||
func (r *family) Reset() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.cleanAclChains(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.ipsetCounter.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
r.rules = make(routeRules)
|
||||
r.filters = make(map[nbid.RuleID]*Rule)
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) updateState() {
|
||||
if r.stateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var currentState *ShutdownState
|
||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
||||
if existingState, ok := existing.(*ShutdownState); ok {
|
||||
currentState = existingState
|
||||
}
|
||||
}
|
||||
if currentState == nil {
|
||||
currentState = &ShutdownState{}
|
||||
}
|
||||
|
||||
currentState.Lock()
|
||||
defer currentState.Unlock()
|
||||
|
||||
if r.v6 {
|
||||
currentState.RouteRules6 = r.rules
|
||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
||||
currentState.ACLEntries6 = r.entries
|
||||
} else {
|
||||
currentState.RouteRules = r.rules
|
||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
||||
currentState.ACLEntries = r.entries
|
||||
}
|
||||
|
||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
}
|
||||
}
|
||||
331
client/firewall/iptables/filter_linux.go
Normal file
331
client/firewall/iptables/filter_linux.go
Normal file
@@ -0,0 +1,331 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// AddFilterRule installs a packet-filtering rule. With destination
|
||||
// empty, the rule goes to the peer ACL input chain plus a paired
|
||||
// mangle PREROUTING rule for the redirect mark. With destination set
|
||||
// (prefix or named set), it goes to the route ACL forward chain.
|
||||
// Multi-source rules collapse to one iptables rule via the shared
|
||||
// hash:net ipset.
|
||||
func (r *family) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
if existing, ok := r.filters[ruleID]; ok {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source match: %w", err)
|
||||
}
|
||||
|
||||
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
r.dropSourceMatch(srcMatch)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.filters[ruleID] = rule
|
||||
r.updateState()
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (r *family) hasRule(id nbid.RuleID) bool {
|
||||
_, ok := r.filters[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
// hasDNATRule reports whether this family owns the DNAT rule set for
|
||||
// the given user id. DNAT rules live in r.rules under the well-known
|
||||
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
|
||||
// to pick the right family.
|
||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteFilterRule removes a previously installed filter rule. The
|
||||
// rule's stored chain/table identify where to delete from; source set
|
||||
// references are recovered from the spec via findSets and dropped
|
||||
// from the shared ipset counter.
|
||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
pr, ok := r.filters[ruleID]
|
||||
if !ok {
|
||||
log.Debugf("filter rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.iptablesClient.Delete(tableFilter, pr.chain, pr.specs...); err != nil {
|
||||
return fmt.Errorf("delete rule %s: %w", pr.chain, err)
|
||||
}
|
||||
if pr.mangleSpecs != nil {
|
||||
if err := r.iptablesClient.Delete(tableMangle, chainRTPRE, pr.mangleSpecs...); err != nil {
|
||||
log.Errorf("delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.dropSourceMatch(pr.specs)
|
||||
delete(r.filters, ruleID)
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
// findSets scans an iptables rule spec for "-m set --match-set <name>
|
||||
// <dir>" fragments and returns the named sets in occurrence order.
|
||||
// Used at delete time to drop ipsetCounter references.
|
||||
func findSets(rule []string) []string {
|
||||
var sets []string
|
||||
for i, arg := range rule {
|
||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||
sets = append(sets, rule[i+3])
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||
// single prefix inline, or an ipset for multiple sources.
|
||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||
switch {
|
||||
case len(sources) == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1:
|
||||
return firewall.Network{Prefix: sources[0]}
|
||||
default:
|
||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||
}
|
||||
}
|
||||
|
||||
// applySourceMatch returns the iptables match fragment for the rule's
|
||||
// source. For a Set it increments the shared ipset's refcount; for a
|
||||
// Prefix it emits a direct -s match; for the wildcard it returns nil.
|
||||
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
switch {
|
||||
case network.IsSet():
|
||||
if r.ipsetCounter == nil {
|
||||
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
|
||||
}
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
|
||||
}
|
||||
return []string{"-m", "set", matchSet, name, "src"}, nil
|
||||
case network.IsPrefix():
|
||||
return []string{"-s", network.Prefix.String()}, nil
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// dropSourceMatch undoes whatever applySourceMatch reserved. Safe to
|
||||
// call when the spec is empty or holds only inline matchers. Decrement
|
||||
// errors are logged but not returned: the filter rule has already been
|
||||
// deleted at that point and we don't want to leak the deletion.
|
||||
func (r *family) dropSourceMatch(srcMatch []string) {
|
||||
if r.ipsetCounter == nil {
|
||||
return
|
||||
}
|
||||
for _, name := range findSets(srcMatch) {
|
||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||
log.Errorf("rollback ipset decrement %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// decrementSetCounter drops ipset references owned by a raw rule spec
|
||||
// stored in r.rules (NAT / legacy route entries). It returns an error
|
||||
// aggregate so the caller surfaces decrement failures.
|
||||
func (r *family) decrementSetCounter(rule []string) error {
|
||||
if r.ipsetCounter == nil {
|
||||
return nil
|
||||
}
|
||||
var merr *multierror.Error
|
||||
for _, name := range findSets(rule) {
|
||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// installFilterRule assembles and writes one iptables filter-chain
|
||||
// rule. With destination empty the rule lands in the peer ACL input
|
||||
// chain and a paired mangle PREROUTING rule is added for the redirect
|
||||
// mark. With destination set the rule lands in the route ACL forward
|
||||
// chain and there is no mangle pairing.
|
||||
func (r *family) installFilterRule(
|
||||
ruleID nbid.RuleID,
|
||||
srcMatch []string,
|
||||
destination firewall.Network,
|
||||
protocol firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (*Rule, error) {
|
||||
isRoute := destination.IsPrefix() || destination.IsSet()
|
||||
|
||||
proto := protoForFamily(protocol, r.v6)
|
||||
|
||||
specs := slices.Clone(srcMatch)
|
||||
var destExp []string
|
||||
if isRoute {
|
||||
var err error
|
||||
destExp, err = r.applyNetwork("-d", destination, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
specs = append(specs, destExp...)
|
||||
}
|
||||
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
|
||||
|
||||
var mangleSpecs []string
|
||||
if !isRoute {
|
||||
mangleSpecs = slices.Clone(specs)
|
||||
mangleSpecs = append(mangleSpecs,
|
||||
"-i", r.wgIface.Name(),
|
||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||
)
|
||||
}
|
||||
|
||||
specs = append(specs, "-j", actionToStr(action))
|
||||
|
||||
chain := chainNameInputRules
|
||||
if isRoute {
|
||||
chain = chainRTFWDIN
|
||||
}
|
||||
|
||||
// Peer ACL drops are inserted at position 1 so they precede the
|
||||
// chain's catch-all; route ACL drops are inserted at position 2
|
||||
// to sit immediately after the established/related accept rule.
|
||||
var err error
|
||||
if action == firewall.ActionDrop {
|
||||
pos := 1
|
||||
if isRoute {
|
||||
pos = 2
|
||||
}
|
||||
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
|
||||
} else {
|
||||
err = r.iptablesClient.Append(tableFilter, chain, specs...)
|
||||
}
|
||||
if err != nil {
|
||||
r.dropSourceMatch(destExp)
|
||||
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
|
||||
}
|
||||
|
||||
if mangleSpecs != nil {
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||
log.Errorf("add mangle rule: %v", err)
|
||||
mangleSpecs = nil
|
||||
}
|
||||
}
|
||||
|
||||
return &Rule{
|
||||
id: ruleID,
|
||||
specs: specs,
|
||||
mangleSpecs: mangleSpecs,
|
||||
chain: chain,
|
||||
v6: r.v6,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// applyNetwork resolves a firewall.Network into the iptables match
|
||||
// fragment for the given direction flag (-s or -d). Set networks
|
||||
// increment the shared ipset refcount; prefixes emit a direct match;
|
||||
// an empty network returns no spec ("match any").
|
||||
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||
direction := "src"
|
||||
if flag == "-d" {
|
||||
direction = "dst"
|
||||
}
|
||||
|
||||
if network.IsSet() {
|
||||
name := r.ipsetName(network.Set.HashedName())
|
||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
||||
}
|
||||
if network.IsPrefix() {
|
||||
return []string{flag, network.Prefix.String()}, nil
|
||||
}
|
||||
|
||||
// nolint:nilnil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||
if v6 && protocol == firewall.ProtocolICMP {
|
||||
return "ipv6-icmp"
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
// filterMatchSpecs returns the proto/port match fragment for a
|
||||
// filtering rule. The source match (-s or -m set) is built by the
|
||||
// caller and prepended.
|
||||
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
|
||||
if protocol != "all" {
|
||||
specs = append(specs, "-p", protocol)
|
||||
}
|
||||
specs = append(specs, applyPort("--sport", sPort)...)
|
||||
specs = append(specs, applyPort("--dport", dPort)...)
|
||||
return specs
|
||||
}
|
||||
|
||||
func actionToStr(action firewall.Action) string {
|
||||
if action == firewall.ActionAccept {
|
||||
return "ACCEPT"
|
||||
}
|
||||
return "DROP"
|
||||
}
|
||||
|
||||
func applyPort(flag string, port *firewall.Port) []string {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
||||
}
|
||||
|
||||
if len(port.Values) > 1 {
|
||||
portList := make([]string, len(port.Values))
|
||||
for i, p := range port.Values {
|
||||
portList[i] = strconv.Itoa(int(p))
|
||||
}
|
||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||
}
|
||||
|
||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||
}
|
||||
97
client/firewall/iptables/ipset_linux.go
Normal file
97
client/firewall/iptables/ipset_linux.go
Normal file
@@ -0,0 +1,97 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
ipset "github.com/lrh3321/ipset-go"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
|
||||
if err := r.createIPSet(setName); err != nil {
|
||||
return fmt.Errorf("create set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
for _, prefix := range sources {
|
||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) deleteIpSet(setName string) error {
|
||||
if err := r.destroyIPSet(setName); err != nil {
|
||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
name := r.ipsetName(set.HashedName())
|
||||
var merr *multierror.Error
|
||||
for _, prefix := range prefixes {
|
||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
||||
}
|
||||
}
|
||||
if merr == nil {
|
||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) ipsetName(name string) string {
|
||||
if r.v6 {
|
||||
return name + "-v6"
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func (r *family) createIPSet(name string) error {
|
||||
opts := ipset.CreateOptions{
|
||||
Replace: true,
|
||||
}
|
||||
if r.v6 {
|
||||
opts.Family = ipset.FamilyIPV6
|
||||
}
|
||||
|
||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
log.Debugf("created ipset %s with type hash:net", name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
||||
addr := prefix.Addr()
|
||||
ip := addr.AsSlice()
|
||||
|
||||
entry := &ipset.Entry{
|
||||
IP: ip,
|
||||
CIDR: uint8(prefix.Bits()),
|
||||
Replace: true,
|
||||
}
|
||||
|
||||
if err := ipset.Add(name, entry); err != nil {
|
||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) destroyIPSet(name string) error {
|
||||
return ipset.Destroy(name)
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package iptables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -18,25 +17,21 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type resetter interface {
|
||||
Reset() error
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
// Manager of iptables firewall. Per-family state (peer ACLs, route
|
||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||
// by family and provides the public firewall.Manager surface.
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
family4 *family
|
||||
rawSupported bool
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
ipv6Client *iptables.IPTables
|
||||
aclMgr6 *aclManager
|
||||
router6 *router
|
||||
family6 *family
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -57,14 +52,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
return nil, fmt.Errorf("create family: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
@@ -83,19 +73,14 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
||||
}
|
||||
m.ipv6Client = ip6Client
|
||||
|
||||
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||
m.family6, err = newFamily(ip6Client, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
return fmt.Errorf("create v6 family: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// Share the same IP forwarding state with the v4 family, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
m.family6.ipFwdState = m.family4.ipFwdState
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -109,7 +94,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
MTU: m.family4.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
@@ -141,31 +126,24 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// initChains initializes router and ACL chains for both address families,
|
||||
// rolling back on failure.
|
||||
// initChains initializes the per-family firewall state for both
|
||||
// address families, rolling back on failure.
|
||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
type initStep struct {
|
||||
name string
|
||||
init func(*statemanager.Manager) error
|
||||
mgr resetter
|
||||
r *family
|
||||
}
|
||||
|
||||
steps := []initStep{
|
||||
{"router", m.router.init, m.router},
|
||||
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||
}
|
||||
steps := []initStep{{"v4", m.family4}}
|
||||
if m.hasIPv6() {
|
||||
steps = append(steps,
|
||||
initStep{"v6 router", m.router6.init, m.router6},
|
||||
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||
)
|
||||
steps = append(steps, initStep{"v6", m.family6})
|
||||
}
|
||||
|
||||
var initialized []initStep
|
||||
for _, s := range steps {
|
||||
if err := s.init(stateManager); err != nil {
|
||||
if err := s.r.init(stateManager); err != nil {
|
||||
for i := len(initialized) - 1; i >= 0; i-- {
|
||||
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||
if rerr := initialized[i].r.Reset(); rerr != nil {
|
||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||
}
|
||||
}
|
||||
@@ -176,84 +154,78 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
//
|
||||
// Comment will be ignored because some system this feature is not supported
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
|
||||
// docs for destination semantics. Mixed-family source lists are split
|
||||
// and dispatched to the v4 / v6 backends.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
isRoute := destination.IsPrefix() || destination.IsSet()
|
||||
if isRoute {
|
||||
fam := m.family4
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
fam = m.family6
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
rule, err := fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
v4, v6 := splitSourcesByFamily(sources)
|
||||
var out []firewall.Rule
|
||||
if len(v4) > 0 {
|
||||
rule, err := m.family4.AddFilterRule(id, v4, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, rule)
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
if len(v6) > 0 {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for v6 sources %v: %w", v6, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
rule, err := m.family6.AddFilterRule(id, v6, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, rule)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// DeleteFilterRule removes a rule previously added via AddFilterRule.
|
||||
// The rule is looked up by id in each family's filter cache.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||
return m.aclMgr6.DeletePeerRule(rule)
|
||||
id := rule.ID()
|
||||
if m.family4.hasRule(id) {
|
||||
return m.family4.DeleteFilterRule(rule)
|
||||
}
|
||||
return m.aclMgr.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.v6
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule.
|
||||
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||
return m.router6.DeleteRouteRule(rule)
|
||||
if m.hasIPv6() && m.family6.hasRule(id) {
|
||||
return m.family6.DeleteFilterRule(rule)
|
||||
}
|
||||
return m.router.DeleteRouteRule(rule)
|
||||
log.Debugf("filter rule %s not found in any family", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -272,10 +244,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
return m.family6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
if err := m.family4.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -284,7 +256,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -300,18 +272,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
return m.family6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -320,11 +292,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
}
|
||||
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -341,19 +313,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclMgr6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||
}
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.aclMgr.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||
}
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||
}
|
||||
|
||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||
@@ -377,11 +343,11 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||
func (m *Manager) AllowNetbird() error {
|
||||
var merr *multierror.Error
|
||||
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||
if _, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}, firewall.Network{}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -402,14 +368,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -424,9 +390,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
return m.family6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
return m.family4.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
@@ -434,10 +400,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||
return m.router6.DeleteDNATRule(rule)
|
||||
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
|
||||
return m.family6.DeleteDNATRule(rule)
|
||||
}
|
||||
return m.router.DeleteDNATRule(rule)
|
||||
return m.family4.DeleteDNATRule(rule)
|
||||
}
|
||||
|
||||
// UpdateSet updates the set with the given prefixes
|
||||
@@ -454,12 +420,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -476,9 +442,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
@@ -490,9 +456,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
@@ -504,9 +470,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
@@ -518,9 +484,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -654,3 +620,22 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
func getConntrackEstablished() []string {
|
||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// splitSourcesByFamily partitions a mixed-family prefix list.
|
||||
func splitSourcesByFamily(sources []netip.Prefix) (v4, v6 []netip.Prefix) {
|
||||
for _, p := range sources {
|
||||
if p.Addr().Is4() || p.Addr().Is4In6() {
|
||||
v4 = append(v4, p)
|
||||
} else {
|
||||
v6 = append(v6, p)
|
||||
}
|
||||
}
|
||||
return v4, v6
|
||||
}
|
||||
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
@@ -72,7 +74,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
IsRange: true,
|
||||
Values: []uint16{8043, 8046},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
for _, r := range rule2 {
|
||||
@@ -83,18 +85,16 @@ func TestIptablesManager(t *testing.T) {
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
err := manager.DeleteFilterRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
// add second rule
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{5353}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Close(nil)
|
||||
@@ -126,7 +126,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{Values: []uint16{22}}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
|
||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
require.NotEmpty(t, rule, "deny rule should not be empty")
|
||||
|
||||
@@ -142,11 +142,11 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
|
||||
// Add accept rule first
|
||||
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
|
||||
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add accept rule")
|
||||
|
||||
// Add deny rule second for same IP/port - this should take precedence
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
|
||||
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
||||
@@ -159,19 +159,23 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
t.Logf(" [%d] %s", i, rule)
|
||||
}
|
||||
|
||||
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
|
||||
// match. Match on that shape instead of the legacy
|
||||
// per-(action,port) ipset names ("deny-http"/"accept-http")
|
||||
// that this test predates.
|
||||
srcMatch := fmt.Sprintf("-s %s/32", ip)
|
||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
if strings.Contains(rule, "DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
|
||||
denyRuleIndex = i
|
||||
}
|
||||
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(rule, "ACCEPT") {
|
||||
if strings.Contains(rule, "-j DROP") {
|
||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||
denyRuleIndex = i
|
||||
}
|
||||
if strings.Contains(rule, "-j ACCEPT") {
|
||||
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
||||
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
|
||||
acceptRuleIndex = i
|
||||
}
|
||||
acceptRuleIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,7 +200,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
@@ -211,26 +214,41 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}()
|
||||
|
||||
var rule2 []fw.Rule
|
||||
t.Run("add second rule", func(t *testing.T) {
|
||||
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
|
||||
ip := netip.MustParseAddr("10.20.0.3")
|
||||
port := &fw.Port{
|
||||
Values: []uint16{443},
|
||||
}
|
||||
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Len(t, rule2, 1)
|
||||
require.Contains(t, rule2[0].(*Rule).specs, "-s",
|
||||
"single-source rule should use direct -s match, not an ipset")
|
||||
require.Empty(t, findSets(rule2[0].(*Rule).specs),
|
||||
"single-source rule should not allocate a shared ipset")
|
||||
})
|
||||
|
||||
t.Run("delete single-source rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||
err := manager.DeleteFilterRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("delete second rule", func(t *testing.T) {
|
||||
for _, r := range rule2 {
|
||||
err := manager.DeletePeerRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
|
||||
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||
t.Run("multi-source uses shared ipset", func(t *testing.T) {
|
||||
sources := []netip.Prefix{
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
|
||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
|
||||
}
|
||||
port := &fw.Port{Values: []uint16{8080}}
|
||||
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add multi-source rule")
|
||||
require.Len(t, multi, 1, "multi-source rule must produce exactly one iptables rule")
|
||||
sets := findSets(multi[0].(*Rule).specs)
|
||||
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
|
||||
|
||||
require.NoError(t, manager.DeleteFilterRule(multi[0]))
|
||||
})
|
||||
|
||||
t.Run("reset check", func(t *testing.T) {
|
||||
@@ -281,7 +299,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
|
||||
269
client/firewall/iptables/nat_linux.go
Normal file
269
client/firewall/iptables/nat_linux.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if !pair.Masquerade {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the current legacy management mode
|
||||
func (r *family) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) addPostroutingRules() error {
|
||||
// First rule for outbound masquerade
|
||||
rule1 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||
"!", "-o", "lo",
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
||||
return fmt.Errorf("add outbound masquerade rule: %w", err)
|
||||
}
|
||||
r.rules["static-nat-outbound"] = rule1
|
||||
|
||||
// Second rule for return traffic masquerade
|
||||
rule2 := []string{
|
||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
"-o", r.wgIface.Name(),
|
||||
"-j", routingFinalNatJump,
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
||||
return fmt.Errorf("add return masquerade rule: %w", err)
|
||||
}
|
||||
r.rules["static-nat-return"] = rule2
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
func (r *family) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.v6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
"-j", chainRTMSSCLAMP,
|
||||
}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
|
||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||
}
|
||||
r.rules[jumpMSSClamp] = jumpRule
|
||||
|
||||
ruleOut := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mss),
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
|
||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||
}
|
||||
r.rules["mss-clamp-out"] = ruleOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert established rule: %w", err)
|
||||
}
|
||||
|
||||
ruleKey := firewall.RuleID("established-" + chain)
|
||||
r.rules[ruleKey] = establishedRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
}
|
||||
|
||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
rule := []string{"-i", r.wgIface.Name()}
|
||||
if pair.Inverse {
|
||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
||||
}
|
||||
|
||||
rule = append(rule,
|
||||
"-m", "conntrack",
|
||||
"--ctstate", "NEW",
|
||||
)
|
||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -s: %w", err)
|
||||
}
|
||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply network -d: %w", err)
|
||||
}
|
||||
|
||||
rule = append(rule, sourceExp...)
|
||||
rule = append(rule, destExp...)
|
||||
rule = append(rule,
|
||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||
)
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||
// TODO: rollback ipset counter
|
||||
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = rule
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
|
||||
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
|
||||
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("marking rule %s not found", ruleKey)
|
||||
}
|
||||
|
||||
r.updateState()
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
defer func() {
|
||||
err := r.Reset()
|
||||
require.NoError(t, err, "Failed to reset router")
|
||||
require.NoError(t, err, "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -334,62 +334,30 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddFilterRule failed")
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
stored, ok := r.filters[ruleKey.ID()]
|
||||
require.True(t, ok, "rule not stored in filters")
|
||||
t.Logf("Internal rule: %v", stored.specs)
|
||||
|
||||
// Log the internal rule
|
||||
t.Logf("Internal rule: %v", rule)
|
||||
|
||||
// Check if the rule exists in iptables
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, stored.specs...)
|
||||
assert.NoError(t, err, "Failed to check rule existence")
|
||||
assert.True(t, exists, "Rule not found in iptables")
|
||||
|
||||
var source firewall.Network
|
||||
if len(tt.sources) > 1 {
|
||||
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||
} else if len(tt.sources) > 0 {
|
||||
source.Prefix = tt.sources[0]
|
||||
}
|
||||
// Verify rule content
|
||||
params := routeFilteringRuleParams{
|
||||
Source: source,
|
||||
Destination: firewall.Network{Prefix: tt.destination},
|
||||
Proto: tt.proto,
|
||||
SPort: tt.sPort,
|
||||
DPort: tt.dPort,
|
||||
Action: tt.action,
|
||||
}
|
||||
|
||||
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||
|
||||
if tt.expectSet {
|
||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||
|
||||
// Check if the set was created
|
||||
_, exists := r.ipsetCounter.Get(setName)
|
||||
assert.True(t, exists, "IPSet not created")
|
||||
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
|
||||
}
|
||||
|
||||
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||
|
||||
// Clean up
|
||||
err = r.DeleteRouteRule(ruleKey)
|
||||
require.NoError(t, err, "Failed to delete rule")
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindSetNameInRule(t *testing.T) {
|
||||
r := &router{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
rule []string
|
||||
@@ -430,7 +398,7 @@ func TestFindSetNameInRule(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := r.findSets(tc.rule)
|
||||
result := findSets(tc.rule)
|
||||
|
||||
if len(result) != len(tc.expected) {
|
||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
package iptables
|
||||
|
||||
// Rule to handle management of rules
|
||||
type Rule struct {
|
||||
ruleID string
|
||||
ipsetName string
|
||||
import "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
|
||||
// Rule to handle management of rules. Source set membership (when the
|
||||
// rule was built against a shared hash:net ipset) is encoded in specs;
|
||||
// DeleteFilterRule recovers it via findSets so the refcounter can drop
|
||||
// the right reference.
|
||||
type Rule struct {
|
||||
id manager.RuleID
|
||||
specs []string
|
||||
mangleSpecs []string
|
||||
ip string
|
||||
chain string
|
||||
v6 bool
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
// ID returns the rule id
|
||||
func (r *Rule) ID() manager.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
package iptables
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
type ipList struct {
|
||||
ips map[string]struct{}
|
||||
}
|
||||
|
||||
func newIpList(ip string) *ipList {
|
||||
ips := make(map[string]struct{})
|
||||
ips[ip] = struct{}{}
|
||||
|
||||
return &ipList{
|
||||
ips: ips,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipList) addIP(ip string) {
|
||||
s.ips[ip] = struct{}{}
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{
|
||||
IPs: s.ips,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPs map[string]struct{} `json:"ips"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ips = temp.IPs
|
||||
|
||||
if temp.IPs == nil {
|
||||
temp.IPs = make(map[string]struct{})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsets map[string]*ipList
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsets: make(map[string]*ipList),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||
s.ipsets[ipsetName] = list
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ipsetNames() []string {
|
||||
names := make([]string, 0, len(s.ipsets))
|
||||
for name := range s.ipsets {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{
|
||||
IPSets: s.ipsets,
|
||||
})
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler
|
||||
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||
temp := struct {
|
||||
IPSets map[string]*ipList `json:"ipsets"`
|
||||
}{}
|
||||
if err := json.Unmarshal(data, &temp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.ipsets = temp.IPSets
|
||||
|
||||
if temp.IPSets == nil {
|
||||
temp.IPSets = make(map[string]*ipList)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -32,14 +32,12 @@ type ShutdownState struct {
|
||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||
|
||||
// IPv6 counterparts
|
||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Name() string {
|
||||
@@ -57,17 +55,14 @@ func (s *ShutdownState) Cleanup() error {
|
||||
}
|
||||
|
||||
if s.RouteRules != nil {
|
||||
ipt.router.rules = s.RouteRules
|
||||
ipt.family4.rules = s.RouteRules
|
||||
}
|
||||
if s.RouteIPsetCounter != nil {
|
||||
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||
}
|
||||
|
||||
if s.ACLEntries != nil {
|
||||
ipt.aclMgr.entries = s.ACLEntries
|
||||
}
|
||||
if s.ACLIPsetStore != nil {
|
||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||
ipt.family4.entries = s.ACLEntries
|
||||
}
|
||||
|
||||
// Clean up v6 state even if the current run has no IPv6.
|
||||
@@ -79,16 +74,13 @@ func (s *ShutdownState) Cleanup() error {
|
||||
}
|
||||
if ipt.hasIPv6() {
|
||||
if s.RouteRules6 != nil {
|
||||
ipt.router6.rules = s.RouteRules6
|
||||
ipt.family6.rules = s.RouteRules6
|
||||
}
|
||||
if s.RouteIPsetCounter6 != nil {
|
||||
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||
}
|
||||
if s.ACLEntries6 != nil {
|
||||
ipt.aclMgr6.entries = s.ACLEntries6
|
||||
}
|
||||
if s.ACLIPsetStore6 != nil {
|
||||
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||
ipt.family6.entries = s.ACLEntries6
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
23
client/firewall/iptables/testhelpers_linux_test.go
Normal file
23
client/firewall/iptables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package iptables
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, _ := netip.AddrFromSlice(ip)
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package manager
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sort"
|
||||
|
||||
@@ -16,6 +15,12 @@ import (
|
||||
// method but the IPv6 firewall components were not initialized.
|
||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
||||
|
||||
// ErrNoSources is returned when AddFilterRule is called with an empty
|
||||
// source list. "Match any source" must be expressed explicitly with a
|
||||
// /0 prefix; an empty list is a caller error and is rejected rather
|
||||
// than silently widening the rule to every source.
|
||||
var ErrNoSources = errors.New("rule has no sources")
|
||||
|
||||
const (
|
||||
ForwardingFormatPrefix = "netbird-fwd-"
|
||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||
@@ -23,13 +28,22 @@ const (
|
||||
NatFormat = "netbird-nat-%s-%t"
|
||||
)
|
||||
|
||||
// RuleID identifies a firewall rule. It is a typed string so the
|
||||
// compiler catches accidental mixing with arbitrary string keys.
|
||||
// RuleID itself satisfies the Rule interface so callers can drop a
|
||||
// bare key into APIs like DeleteFilterRule without wrapping it.
|
||||
type RuleID string
|
||||
|
||||
// ID implements the Rule interface for a bare RuleID.
|
||||
func (r RuleID) ID() RuleID { return r }
|
||||
|
||||
// Rule abstraction should be implemented by each firewall manager
|
||||
//
|
||||
// Each firewall type for different OS can use different type
|
||||
// of the properties to hold data of the created rule
|
||||
type Rule interface {
|
||||
// ID returns the rule id
|
||||
ID() string
|
||||
ID() RuleID
|
||||
}
|
||||
|
||||
// RuleDirection is the traffic direction which a rule is applied
|
||||
@@ -101,43 +115,42 @@ type Manager interface {
|
||||
// AllowNetbird allows netbird interface traffic
|
||||
AllowNetbird() error
|
||||
|
||||
// AddPeerFiltering adds a rule to the firewall
|
||||
// AddFilterRule adds a packet-filtering rule to the firewall.
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
// If destination is the zero Network, the rule applies to traffic
|
||||
// inbound to this node, i.e. peer ACL semantics, installed in
|
||||
// the kernel's input chain. If destination is set (prefix or
|
||||
// set), the rule applies to forwarded traffic with that
|
||||
// destination, route ACL semantics, installed in the forward
|
||||
// chain.
|
||||
//
|
||||
// Note: Callers should call Flush() after adding rules to ensure
|
||||
// they are applied to the kernel and rule handles are refreshed.
|
||||
AddPeerFiltering(
|
||||
// sources may mix IPv4 and IPv6 prefixes; backends split by
|
||||
// family and return one rule per family. "Match any source" must
|
||||
// be expressed with an explicit /0 prefix; an empty sources list
|
||||
// is rejected with ErrNoSources so a zeroed list can never widen a
|
||||
// rule to every source.
|
||||
//
|
||||
// Note: callers should call Flush() after adding rules.
|
||||
AddFilterRule(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort *Port,
|
||||
dPort *Port,
|
||||
action Action,
|
||||
ipsetName string,
|
||||
) ([]Rule, error)
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
DeletePeerRule(rule Rule) error
|
||||
// DeleteFilterRule removes a filtering rule previously added via
|
||||
// AddFilterRule. The rule's own type identifies whether it lives
|
||||
// in the peer (input) or route (forward) path.
|
||||
DeleteFilterRule(rule Rule) error
|
||||
|
||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||
IsServerRouteSupported() bool
|
||||
|
||||
IsStateful() bool
|
||||
|
||||
AddRouteFiltering(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination Network,
|
||||
proto Protocol,
|
||||
sPort, dPort *Port,
|
||||
action Action,
|
||||
) (Rule, error)
|
||||
|
||||
// DeleteRouteRule deletes a routing rule
|
||||
DeleteRouteRule(rule Rule) error
|
||||
|
||||
// AddNatRule inserts a routing NAT rule
|
||||
AddNatRule(pair RouterPair) error
|
||||
|
||||
@@ -185,8 +198,8 @@ type Manager interface {
|
||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||
}
|
||||
|
||||
func GenKey(format string, pair RouterPair) string {
|
||||
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||
func GenKey(format string, pair RouterPair) RuleID {
|
||||
return RuleID(fmt.Sprintf(format, pair.ID, pair.Inverse))
|
||||
}
|
||||
|
||||
// LegacyManager defines the interface for legacy management operations
|
||||
@@ -242,6 +255,20 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||
return merged
|
||||
}
|
||||
|
||||
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
|
||||
// plain v4 form, shifting the prefix length out of the 96-bit mapped
|
||||
// range. Other prefixes are returned unchanged. Keeping prefixes
|
||||
// unmapped ensures v4 rules match consistently and the match builders
|
||||
// read the correct address length.
|
||||
func UnmapPrefix(p netip.Prefix) netip.Prefix {
|
||||
addr := p.Addr()
|
||||
if !addr.Is4In6() {
|
||||
return p
|
||||
}
|
||||
bits := max(p.Bits()-96, 0)
|
||||
return netip.PrefixFrom(addr.Unmap(), bits)
|
||||
}
|
||||
|
||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||
func SortPrefixes(prefixes []netip.Prefix) {
|
||||
|
||||
@@ -13,13 +13,13 @@ type ForwardRule struct {
|
||||
TranslatedPort Port
|
||||
}
|
||||
|
||||
func (r ForwardRule) ID() string {
|
||||
func (r ForwardRule) ID() RuleID {
|
||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||
r.Protocol,
|
||||
r.DestinationPort.String(),
|
||||
r.TranslatedAddress.String(),
|
||||
r.TranslatedPort.String())
|
||||
return id
|
||||
return RuleID(id)
|
||||
}
|
||||
|
||||
func (r ForwardRule) String() string {
|
||||
|
||||
@@ -40,7 +40,8 @@ func (h Set) Comment() string {
|
||||
|
||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||
// sort for consistent naming
|
||||
// sort a copy for consistent naming without mutating the caller's slice
|
||||
prefixes = slices.Clone(prefixes)
|
||||
SortPrefixes(prefixes)
|
||||
|
||||
hash := sha256.New()
|
||||
|
||||
@@ -1,713 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
// rules chains contains the effective ACL rules
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
|
||||
// filter chains contains the rules that jump to the rules chains
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
|
||||
type AclManager struct {
|
||||
rConn *nftables.Conn
|
||||
sConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
routingFwChainName string
|
||||
af addrFamily
|
||||
|
||||
workTable *nftables.Table
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
|
||||
ipsetStore *ipsetStore
|
||||
rules map[string]*Rule
|
||||
}
|
||||
|
||||
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||
// sConn is used for creating sets and adding/removing elements from them
|
||||
// it's differ then rConn (which does create new conn for each flush operation)
|
||||
// and is permanent. Using same connection for both type of operations
|
||||
// overloads netlink with high amount of rules ( > 10000)
|
||||
sConn, err := nftables.New(nftables.AsLasting())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||
}
|
||||
|
||||
return &AclManager{
|
||||
rConn: &nftables.Conn{},
|
||||
sConn: sConn,
|
||||
wgIface: wgIface,
|
||||
workTable: table,
|
||||
routingFwChainName: routingFwChainName,
|
||||
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||
|
||||
ipsetStore: newIpsetStore(),
|
||||
rules: make(map[string]*Rule),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||
m.workTable = workTable
|
||||
return m.createDefaultChains()
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *AclManager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
var ipset *nftables.Set
|
||||
if ipsetName != "" {
|
||||
var err error
|
||||
ipset, err = m.addIpToSet(ipsetName, ip)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
newRules := make([]firewall.Rule, 0, 2)
|
||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRules = append(newRules, ioRule)
|
||||
return newRules, nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
r, ok := rule.(*Rule)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid rule type")
|
||||
}
|
||||
|
||||
if r.nftSet == nil {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||
if !ok {
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
delete(m.rules, r.ID())
|
||||
return m.rConn.Flush()
|
||||
}
|
||||
|
||||
if _, ok := ips[r.ip.String()]; ok {
|
||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||
if err != nil {
|
||||
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||
}
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
|
||||
return err
|
||||
}
|
||||
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
|
||||
}
|
||||
|
||||
// if after delete, set still contains other IPs,
|
||||
// no need to delete firewall rule and we should exit here
|
||||
if len(ips) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||
log.Errorf("failed to delete rule: %v", err)
|
||||
}
|
||||
if r.mangleRule != nil {
|
||||
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||
log.Errorf("failed to delete mangle rule: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m.rules, r.ID())
|
||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||
|
||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// we delete last IP from the set, that means we need to delete
|
||||
// set itself and associated firewall rule too
|
||||
m.rConn.FlushSet(r.nftSet)
|
||||
m.rConn.DelSet(r.nftSet)
|
||||
m.ipsetStore.deleteIpset(r.nftSet.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainInputRules,
|
||||
Position: 0,
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (m *AclManager) Flush() error {
|
||||
if err := m.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addIOFiltering(
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
nftRule: r.nftRule,
|
||||
mangleRule: r.mangleRule,
|
||||
nftSet: r.nftSet,
|
||||
ruleID: r.ruleID,
|
||||
ip: ip,
|
||||
}, nil
|
||||
}
|
||||
|
||||
var expressions []expr.Any
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
expressions = append(expressions, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.protoOffset,
|
||||
Len: uint32(1),
|
||||
})
|
||||
|
||||
protoData, err := m.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||
}
|
||||
|
||||
expressions = append(expressions, &expr.Cmp{
|
||||
Register: 1,
|
||||
Op: expr.CmpOpEq,
|
||||
Data: []byte{protoData},
|
||||
})
|
||||
}
|
||||
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||
// in that case not add IP match expression into the rule definition
|
||||
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||
expressions = append(expressions,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: m.af.srcAddrOffset,
|
||||
Len: m.af.addrLen,
|
||||
},
|
||||
)
|
||||
// add individual IP for match if no ipset defined
|
||||
if ipset == nil {
|
||||
expressions = append(expressions,
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rawIP,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
expressions = append(expressions,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ipset.Name,
|
||||
SetID: ipset.ID,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
expressions = append(expressions, applyPort(sPort, true)...)
|
||||
expressions = append(expressions, applyPort(dPort, false)...)
|
||||
|
||||
mainExpressions := slices.Clone(expressions)
|
||||
|
||||
switch action {
|
||||
case firewall.ActionAccept:
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||
case firewall.ActionDrop:
|
||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||
}
|
||||
|
||||
userData := []byte(ruleId)
|
||||
|
||||
chain := m.chainInputRules
|
||||
rule := &nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: mainExpressions,
|
||||
UserData: userData,
|
||||
}
|
||||
|
||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||
var nftRule *nftables.Rule
|
||||
if action == firewall.ActionDrop {
|
||||
nftRule = m.rConn.InsertRule(rule)
|
||||
} else {
|
||||
nftRule = m.rConn.AddRule(rule)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||
}
|
||||
|
||||
ruleStruct := &Rule{
|
||||
nftRule: nftRule,
|
||||
// best effort mangle rule
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
ip: ip,
|
||||
}
|
||||
m.rules[ruleId] = ruleStruct
|
||||
if ipset != nil {
|
||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||
}
|
||||
|
||||
return ruleStruct, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if m.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := m.createChain(chainNameInputRules)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
m.chainInputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
// Chain is created by route manager
|
||||
// TODO: move creation to a common place
|
||||
m.chainPrerouting = &nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: m.workTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
}
|
||||
|
||||
m.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
m.rConn.InsertRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: m.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: m.workTable,
|
||||
}
|
||||
|
||||
chain = m.rConn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: m.workTable,
|
||||
Hooknum: hookNum,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
return m.rConn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = m.rConn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||
rawIP := ipToBytes(ip, m.af)
|
||||
if err != nil {
|
||||
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||
return nil, fmt.Errorf("get set name: %v", err)
|
||||
}
|
||||
|
||||
m.ipsetStore.newIpset(ipset.Name)
|
||||
}
|
||||
|
||||
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||
}
|
||||
|
||||
m.ipsetStore.AddIpToSet(ipset.Name, ip)
|
||||
|
||||
if err := m.sConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
// createSet in given table by name
|
||||
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
|
||||
ipset := &nftables.Set{
|
||||
Name: name,
|
||||
Table: table,
|
||||
Dynamic: true,
|
||||
KeyType: m.af.setKeyType,
|
||||
}
|
||||
|
||||
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||
return nil, fmt.Errorf("create set: %v", err)
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush created set: %v", err)
|
||||
}
|
||||
|
||||
return ipset, nil
|
||||
}
|
||||
|
||||
func (m *AclManager) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if m.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := m.rConn.GetRules(m.workTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) == 0 {
|
||||
continue
|
||||
}
|
||||
split := bytes.Split(rule.UserData, []byte(" "))
|
||||
r, ok := m.rules[string(split[0])]
|
||||
if ok {
|
||||
if mangle {
|
||||
*r.mangleRule = *rule
|
||||
} else {
|
||||
*r.nftRule = *rule
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":" + string(proto) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
if dPort != nil {
|
||||
rulesetID += dPort.String()
|
||||
}
|
||||
rulesetID += ":"
|
||||
rulesetID += strconv.Itoa(int(action))
|
||||
if ipset == nil {
|
||||
return "ip:" + ip.String() + rulesetID
|
||||
}
|
||||
return "set:" + ipset.Name + rulesetID
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
|
||||
|
||||
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||
if af.addrLen == 4 {
|
||||
return ip.To4()
|
||||
}
|
||||
return ip.To16()
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -63,6 +64,14 @@ func familyForAddr(is4 bool) addrFamily {
|
||||
return afIPv6
|
||||
}
|
||||
|
||||
// zeroPrefix returns the family's unspecified prefix (/0).
|
||||
func (af addrFamily) zeroPrefix() netip.Prefix {
|
||||
if af.addrLen == net.IPv4len {
|
||||
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||
}
|
||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||
}
|
||||
|
||||
// protoNum converts a firewall protocol to the IP protocol number,
|
||||
// using the correct ICMP variant for the address family.
|
||||
func (af addrFamily) protoNum(protocol firewall.Protocol) (uint8, error) {
|
||||
|
||||
885
client/firewall/nftables/chains_linux.go
Normal file
885
client/firewall/nftables/chains_linux.go
Normal file
@@ -0,0 +1,885 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) createContainers() error {
|
||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingFw,
|
||||
Table: r.workTable,
|
||||
})
|
||||
|
||||
prio := *nftables.ChainPriorityNATSource - 1
|
||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingNat,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: &prio,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameRoutingRdr,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePostrouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPostrouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameManglePrerouting,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameMangleForward,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
||||
|
||||
r.addPostroutingRules()
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("initialize tables: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
|
||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||
}
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
log.Errorf("failed to refresh rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupDataPlaneMark configures the fwmark for the data plane
|
||||
func (r *family) setupDataPlaneMark() error {
|
||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||
return errors.New("no mangle chains found")
|
||||
}
|
||||
|
||||
ctNew := getCtNewExprs()
|
||||
preExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
preExprs = append(preExprs, ctNew...)
|
||||
preExprs = append(preExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
preNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: preExprs,
|
||||
}
|
||||
r.conn.AddRule(preNftRule)
|
||||
|
||||
postExprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
postExprs = append(postExprs, ctNew...)
|
||||
postExprs = append(postExprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
postNftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePostrouting],
|
||||
Exprs: postExprs,
|
||||
}
|
||||
r.conn.AddRule(postNftRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) acceptForwardRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.acceptFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) acceptFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available.
|
||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.acceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
fw = "nftables"
|
||||
return r.acceptFilterRulesNftables(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) getAcceptForwardRules() [][]string {
|
||||
intf := r.wgIface.Name()
|
||||
return [][]string{
|
||||
{"-i", intf, "-j", "ACCEPT"},
|
||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
||||
// This is used when iptables is not available.
|
||||
func (r *family) acceptFilterRulesNftables(table *nftables.Table) error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
forwardChain := &nftables.Chain{
|
||||
Name: chainNameForward,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertForwardAcceptRules(forwardChain, intf)
|
||||
|
||||
inputChain := &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: table,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
}
|
||||
r.insertInputAcceptRule(inputChain, intf)
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
||||
func (r *family) acceptExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
intf := ifname(r.wgIface.Name())
|
||||
for _, chain := range chains {
|
||||
r.applyExternalChainAccept(chain, intf)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush external chain rules: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte) {
|
||||
if chain.Hooknum == nil {
|
||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
||||
|
||||
switch *chain.Hooknum {
|
||||
case *nftables.ChainHookForward:
|
||||
r.insertForwardAcceptRules(chain, intf)
|
||||
case *nftables.ChainHookInput:
|
||||
r.insertInputAcceptRule(chain, intf)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
r.insertForwardIifRule(chain, intf, existing)
|
||||
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
||||
}
|
||||
|
||||
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleIif] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
||||
if existing[userDataAcceptForwardRuleOif] {
|
||||
return
|
||||
}
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
||||
if err != nil {
|
||||
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
return
|
||||
}
|
||||
if existing[userDataAcceptInputRule] {
|
||||
return
|
||||
}
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
})
|
||||
}
|
||||
|
||||
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
||||
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list rules: %w", err)
|
||||
}
|
||||
present := map[string]bool{}
|
||||
for _, rule := range rules {
|
||||
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
||||
continue
|
||||
}
|
||||
present[string(rule.UserData)] = true
|
||||
}
|
||||
return present, nil
|
||||
}
|
||||
|
||||
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
||||
switch string(userData) {
|
||||
case userDataAcceptForwardRuleIif,
|
||||
userDataAcceptForwardRuleOif,
|
||||
userDataAcceptInputRule:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptFilterRules() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeFilterTableRules(); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeFilterTableRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
||||
if err != nil {
|
||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
|
||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != table.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
||||
// ensuring cleanup works even after a crash or if chains changed.
|
||||
func (r *family) removeExternalChainsRules() error {
|
||||
chains := r.findExternalChains()
|
||||
if len(chains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
||||
log.Warnf("remove rules from external chain %s/%s: %v", chain.Table.Name, chain.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
||||
func (r *family) findExternalChains() []*nftables.Chain {
|
||||
var chains []*nftables.Chain
|
||||
|
||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
||||
|
||||
for _, family := range families {
|
||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
||||
if err != nil {
|
||||
log.Debugf("list chains for family %d: %v", family, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, chain := range allChains {
|
||||
if r.isExternalChain(chain) {
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return chains
|
||||
}
|
||||
|
||||
func (r *family) isExternalChain(chain *nftables.Chain) bool {
|
||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
||||
if chain.Table.Name == firewalldTableName {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Type != nftables.ChainTypeFilter {
|
||||
return false
|
||||
}
|
||||
|
||||
if chain.Hooknum == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
||||
}
|
||||
|
||||
func isIptablesTable(name string) bool {
|
||||
switch name {
|
||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chainInputRules,
|
||||
Position: 0,
|
||||
Exprs: expIn,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush rule/chain/set operations from the buffer
|
||||
//
|
||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||
func (r *family) Flush() error {
|
||||
if err := r.flushWithBackoff(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
|
||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||
}
|
||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||
if r.chainPrerouting == nil {
|
||||
log.Warn("prerouting chain is not created")
|
||||
return nil
|
||||
}
|
||||
|
||||
preroutingExprs := slices.Clone(expressions)
|
||||
|
||||
// interface
|
||||
preroutingExprs = append([]expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}, preroutingExprs...)
|
||||
|
||||
// local destination and mark
|
||||
preroutingExprs = append(preroutingExprs,
|
||||
&expr.Fib{
|
||||
Register: 1,
|
||||
ResultADDRTYPE: true,
|
||||
FlagDADDR: true,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||
},
|
||||
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
)
|
||||
|
||||
nfRule := r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (r *family) createDefaultChains() (err error) {
|
||||
// chainNameInputRules
|
||||
chain := r.createChain(chainNameInputRules)
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
r.chainInputRules = chain
|
||||
|
||||
// netbird-acl-input-filter
|
||||
// type filter hook input priority filter; policy accept;
|
||||
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// netbird-acl-forward-filter
|
||||
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||
// netbird peer IP.
|
||||
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||
r.chainPrerouting = r.chains[chainNameManglePrerouting]
|
||||
|
||||
r.addFwmarkToForward(chainFwFilter)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||
r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: r.routingFwChainName,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chainFwFilter,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) createChain(name string) *nftables.Chain {
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: r.workTable,
|
||||
}
|
||||
|
||||
chain = r.conn.AddChain(chain)
|
||||
|
||||
insertReturnTrafficRule(r.conn, r.workTable, chain)
|
||||
|
||||
return chain
|
||||
}
|
||||
|
||||
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||
polAccept := nftables.ChainPolicyAccept
|
||||
chain := &nftables.Chain{
|
||||
Name: name,
|
||||
Table: r.workTable,
|
||||
Hooknum: hookNum,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Policy: &polAccept,
|
||||
}
|
||||
|
||||
return r.conn.AddChain(chain)
|
||||
}
|
||||
|
||||
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||
}
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||
expressions := []expr.Any{
|
||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictJump,
|
||||
Chain: to,
|
||||
},
|
||||
}
|
||||
|
||||
_ = r.conn.AddRule(&nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: expressions,
|
||||
})
|
||||
}
|
||||
|
||||
func (r *family) flushWithBackoff() (err error) {
|
||||
backoff := 4
|
||||
backoffTime := 1000 * time.Millisecond
|
||||
for i := 0; ; i++ {
|
||||
err = r.conn.Flush()
|
||||
if err != nil {
|
||||
log.Debugf("failed to flush nftables: %v", err)
|
||||
if !strings.Contains(err.Error(), "busy") {
|
||||
return
|
||||
}
|
||||
log.Error("failed to flush nftables, retrying...")
|
||||
if i == backoff-1 {
|
||||
return err
|
||||
}
|
||||
time.Sleep(backoffTime)
|
||||
backoffTime *= 2
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||
if r.workTable == nil || chain == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
list, err := r.conn.GetRules(r.workTable, chain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, rule := range list {
|
||||
if len(rule.UserData) == 0 {
|
||||
continue
|
||||
}
|
||||
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if mangle {
|
||||
if pr.mangleRule != nil {
|
||||
*pr.mangleRule = *rule
|
||||
}
|
||||
} else {
|
||||
*pr.nftRule = *rule
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
533
client/firewall/nftables/dnat_linux.go
Normal file
533
client/firewall/nftables/dnat_linux.go
Normal file
@@ -0,0 +1,533 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.addDnatMasq(rule, protoNum, ruleKey)
|
||||
|
||||
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||
// TODO: find chains with drop policies and add rules there
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush rules: %w", err)
|
||||
}
|
||||
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey firewall.RuleID) error {
|
||||
dnatExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
}
|
||||
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||
|
||||
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
|
||||
}
|
||||
}
|
||||
|
||||
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: regProtoMin,
|
||||
RegProtoMax: regProtoMax,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleKey + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
switch {
|
||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||
return r.handlePortRange(rule)
|
||||
case len(rule.TranslatedPort.Values) == 0:
|
||||
return r.handleAddressOnly(rule)
|
||||
case len(rule.TranslatedPort.Values) == 1:
|
||||
return r.handleSinglePort(rule)
|
||||
default:
|
||||
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 3, nil
|
||||
}
|
||||
|
||||
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
return exprs, 0, 0, nil
|
||||
}
|
||||
|
||||
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||
},
|
||||
}
|
||||
return exprs, 2, 0, nil
|
||||
}
|
||||
|
||||
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleKey firewall.RuleID, rule firewall.ForwardRule) error {
|
||||
dnatExprs = append(dnatExprs,
|
||||
&expr.Counter{},
|
||||
&expr.Target{
|
||||
Name: "DNAT",
|
||||
Rev: 2,
|
||||
Info: &xt.NatRange2{
|
||||
NatRange: xt.NatRange{
|
||||
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||
MinPort: rule.TranslatedPort.Values[0],
|
||||
MaxPort: rule.TranslatedPort.Values[1],
|
||||
},
|
||||
BasePort: rule.DestinationPort.Values[0],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
natTable := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: natTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: natTable,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
},
|
||||
Exprs: dnatExprs,
|
||||
UserData: []byte(ruleKey + dnatSuffix),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey firewall.RuleID) {
|
||||
masqExprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.dstAddrOffset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: rule.TranslatedAddress.AsSlice(),
|
||||
},
|
||||
}
|
||||
|
||||
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||
masqExprs = append(masqExprs, &expr.Masq{})
|
||||
|
||||
masqRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: masqExprs,
|
||||
UserData: []byte(ruleKey + snatSuffix),
|
||||
}
|
||||
r.conn.AddRule(masqRule)
|
||||
r.rules[ruleKey+snatSuffix] = masqRule
|
||||
}
|
||||
|
||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
log.Errorf("%v", err)
|
||||
}
|
||||
|
||||
ruleKey := rule.ID()
|
||||
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
var needsFlush bool
|
||||
|
||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||
if dnatRule.Handle == 0 {
|
||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||
if masqRule.Handle == 0 {
|
||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||
} else {
|
||||
needsFlush = true
|
||||
}
|
||||
}
|
||||
|
||||
if needsFlush {
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
}
|
||||
|
||||
if merr == nil {
|
||||
delete(r.rules, ruleKey+dnatSuffix)
|
||||
delete(r.rules, ruleKey+snatSuffix)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 3,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 3,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
RegProtoMax: 0,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingRdr],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
||||
func (r *family) ensureNATOutputChain() error {
|
||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameNATOutput,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
delete(r.chains, chainNameNATOutput)
|
||||
return fmt.Errorf("create NAT output chain: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
if _, exists := r.rules[ruleID]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.ensureNATOutputChain(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoNum, err := r.af.protoNum(protocol)
|
||||
if err != nil {
|
||||
return fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{protoNum},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 2,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
||||
},
|
||||
}
|
||||
|
||||
bits := 32
|
||||
if localAddr.Is6() {
|
||||
bits = 128
|
||||
}
|
||||
exprs = append(exprs, r.applyPrefix(netip.PrefixFrom(localAddr, bits), false)...)
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: localAddr.AsSlice(),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 2,
|
||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(r.af.tableFamily),
|
||||
RegAddrMin: 1,
|
||||
RegProtoMin: 2,
|
||||
},
|
||||
)
|
||||
|
||||
dnatRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameNATOutput],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleID),
|
||||
}
|
||||
r.conn.AddRule(dnatRule)
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
||||
}
|
||||
|
||||
r.rules[ruleID] = dnatRule
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
||||
|
||||
rule, exists := r.rules[ruleID]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||
delete(r.rules, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
||||
}
|
||||
delete(r.rules, ruleID)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -52,9 +52,10 @@ func (m *externalChainMonitor) start() {
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
m.done = make(chan struct{})
|
||||
done := make(chan struct{})
|
||||
m.done = done
|
||||
|
||||
go m.run(ctx)
|
||||
go m.run(ctx, done)
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) stop() {
|
||||
@@ -72,8 +73,8 @@ func (m *externalChainMonitor) stop() {
|
||||
<-done
|
||||
}
|
||||
|
||||
func (m *externalChainMonitor) run(ctx context.Context) {
|
||||
defer close(m.done)
|
||||
func (m *externalChainMonitor) run(ctx context.Context, done chan struct{}) {
|
||||
defer close(done)
|
||||
|
||||
bo := &backoff.ExponentialBackOff{
|
||||
InitialInterval: externalMonitorInitInterval,
|
||||
|
||||
249
client/firewall/nftables/family_linux.go
Normal file
249
client/firewall/nftables/family_linux.go
Normal file
@@ -0,0 +1,249 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/coreos/go-iptables/iptables"
|
||||
"github.com/google/nftables"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
const (
|
||||
tableNat = "nat"
|
||||
tableMangle = "mangle"
|
||||
tableRaw = "raw"
|
||||
tableSecurity = "security"
|
||||
|
||||
chainNameNatPrerouting = "PREROUTING"
|
||||
chainNameRoutingFw = "netbird-rt-fwd"
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameNATOutput = "netbird-nat-output"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
// Peer ACL chain names.
|
||||
chainNameInputRules = "netbird-acl-input-rules"
|
||||
chainNameInputFilter = "netbird-acl-input-filter"
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
flushError = "flush: %w"
|
||||
|
||||
firewalldTableName = "firewalld"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
|
||||
dnatSuffix firewall.RuleID = "_dnat"
|
||||
snatSuffix firewall.RuleID = "_snat"
|
||||
|
||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
||||
ipv4TCPHeaderSize = 40
|
||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
||||
ipv6TCPHeaderSize = 60
|
||||
|
||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
||||
maxPrefixesSet = 1500
|
||||
refreshRulesMapError = "refresh rules map: %w"
|
||||
)
|
||||
|
||||
var (
|
||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||
)
|
||||
|
||||
type setInput struct {
|
||||
set firewall.Set
|
||||
prefixes []netip.Prefix
|
||||
}
|
||||
|
||||
// family holds the per-address-family nftables state. One instance
|
||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
||||
// single family; the top-level Manager owns one for v4 and another
|
||||
// for v6. The name predates the peer-ACL absorption; it's effectively
|
||||
// the per-family backend now.
|
||||
type family struct {
|
||||
conn *nftables.Conn
|
||||
workTable *nftables.Table
|
||||
filterTable *nftables.Table
|
||||
chains map[string]*nftables.Chain
|
||||
|
||||
// filters holds peer + route filter rules keyed by content hash.
|
||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
||||
filters map[firewall.RuleID]*Rule
|
||||
|
||||
// rules holds NAT, DNAT, and external accept rules (auxiliary
|
||||
// plumbing that isn't a filter rule).
|
||||
rules map[firewall.RuleID]*nftables.Rule
|
||||
|
||||
// Peer ACL chain handles.
|
||||
chainInputRules *nftables.Chain
|
||||
chainPrerouting *nftables.Chain
|
||||
routingFwChainName string
|
||||
|
||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||
|
||||
af addrFamily
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
||||
r := &family{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
chains: make(map[string]*nftables.Chain),
|
||||
filters: make(map[firewall.RuleID]*Rule),
|
||||
rules: make(map[firewall.RuleID]*nftables.Rule),
|
||||
routingFwChainName: chainNameRoutingFw,
|
||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
r.createIpSet,
|
||||
r.deleteIpSet,
|
||||
)
|
||||
|
||||
var err error
|
||||
r.filterTable, err = r.loadFilterTable()
|
||||
if err != nil {
|
||||
log.Debugf("ip filter table not found: %v", err)
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *family) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
return fmt.Errorf("create containers: %w", err)
|
||||
}
|
||||
|
||||
if err := r.setupDataPlaneMark(); err != nil {
|
||||
log.Errorf("failed to set up data plane mark: %v", err)
|
||||
}
|
||||
|
||||
if err := r.createDefaultChains(); err != nil {
|
||||
return fmt.Errorf("create default acl chains: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables filter table rules from the system
|
||||
func (r *family) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) loadFilterTable() (*nftables.Table, error) {
|
||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if table.Name == "filter" {
|
||||
return table, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errFilterTableNotFound
|
||||
}
|
||||
|
||||
func hookName(hook *nftables.ChainHook) string {
|
||||
if hook == nil {
|
||||
return "unknown"
|
||||
}
|
||||
switch *hook {
|
||||
case *nftables.ChainHookForward:
|
||||
return chainNameForward
|
||||
case *nftables.ChainHookInput:
|
||||
return chainNameInput
|
||||
default:
|
||||
return fmt.Sprintf("hook(%d)", *hook)
|
||||
}
|
||||
}
|
||||
|
||||
func familyName(family nftables.TableFamily) string {
|
||||
switch family {
|
||||
case nftables.TableFamilyIPv4:
|
||||
return "ip"
|
||||
case nftables.TableFamilyIPv6:
|
||||
return "ip6"
|
||||
case nftables.TableFamilyINet:
|
||||
return "inet"
|
||||
default:
|
||||
return fmt.Sprintf("family(%d)", family)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *family) iptablesProto() iptables.Protocol {
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
return iptables.ProtocolIPv6
|
||||
}
|
||||
return iptables.ProtocolIPv4
|
||||
}
|
||||
|
||||
func (r *family) refreshRulesMap() error {
|
||||
var merr *multierror.Error
|
||||
newRules := make(map[firewall.RuleID]*nftables.Rule)
|
||||
for _, chain := range r.chains {
|
||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||
// preserve existing entries for this chain since we can't verify their state
|
||||
for k, v := range r.rules {
|
||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||
newRules[k] = v
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 {
|
||||
newRules[firewall.RuleID(rule.UserData)] = rule
|
||||
}
|
||||
}
|
||||
}
|
||||
r.rules = newRules
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
444
client/firewall/nftables/filter_linux.go
Normal file
444
client/firewall/nftables/filter_linux.go
Normal file
@@ -0,0 +1,444 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||
)
|
||||
|
||||
// AddFilterRule installs one nftables packet-filter rule. With
|
||||
// destination empty the rule goes to the peer ACL input chain plus a
|
||||
// paired prerouting mangle rule for the redirect mark. With
|
||||
// destination set (prefix or named set) it goes to the route ACL
|
||||
// forward chain. Multi-source rules collapse to one nftables rule
|
||||
// backed by the shared refcounted hash:net set.
|
||||
func (r *family) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
isRoute := destination.IsPrefix() || destination.IsSet()
|
||||
|
||||
// Peer ACL with no sources is the v4 wildcard. Route paths never
|
||||
// hit this branch; their callers always carry a destination so the
|
||||
// source list can legitimately be empty.
|
||||
if !isRoute && len(sources) == 0 {
|
||||
sources = []netip.Prefix{r.af.zeroPrefix()}
|
||||
}
|
||||
|
||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
if existing, ok := r.filters[ruleID]; ok {
|
||||
return existing, nil
|
||||
}
|
||||
|
||||
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
exprs, err := r.buildFilterExprs(srcExprs, destination, proto, sPort, dPort, isRoute)
|
||||
if err != nil {
|
||||
r.dropSourceMatch(srcExprs)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mainExprs := slices.Clone(exprs)
|
||||
verdict := expr.VerdictAccept
|
||||
if action == firewall.ActionDrop {
|
||||
verdict = expr.VerdictDrop
|
||||
}
|
||||
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
|
||||
|
||||
chain := r.chainInputRules
|
||||
if isRoute {
|
||||
chain = r.chains[chainNameRoutingFw]
|
||||
}
|
||||
|
||||
userData := []byte(ruleID)
|
||||
nftRule := &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: chain,
|
||||
Exprs: mainExprs,
|
||||
UserData: userData,
|
||||
}
|
||||
if action == firewall.ActionDrop {
|
||||
nftRule = r.conn.InsertRule(nftRule)
|
||||
} else {
|
||||
nftRule = r.conn.AddRule(nftRule)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.dropSourceMatch(exprs)
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
rule := &Rule{
|
||||
nftRule: nftRule,
|
||||
sources: sources,
|
||||
id: ruleID,
|
||||
}
|
||||
if !isRoute {
|
||||
rule.mangleRule = r.createPreroutingRule(exprs, userData)
|
||||
}
|
||||
r.filters[ruleID] = rule
|
||||
|
||||
if isRoute {
|
||||
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
|
||||
sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// buildFilterExprs assembles the non-verdict portion of a filter
|
||||
// rule. Route rules use Meta L4PROTO + Counter; peer rules read the
|
||||
// IP-header protocol byte via Payload and skip the counter, matching
|
||||
// the historical shapes so the per-rule kernel state is identical to
|
||||
// pre-unification.
|
||||
func (r *family) buildFilterExprs(
|
||||
srcExprs []expr.Any,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
isRoute bool,
|
||||
) ([]expr.Any, error) {
|
||||
var exprs []expr.Any
|
||||
|
||||
if isRoute {
|
||||
exprs = append(exprs, srcExprs...)
|
||||
|
||||
destExprs, err := r.applyNetwork(destination, nil, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
exprs = append(exprs, destExprs...)
|
||||
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
r.dropSourceMatch(destExprs)
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||
)
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Counter{})
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
// Peer ACL shape: protocol header read first, then source, then ports.
|
||||
if proto != firewall.ProtocolALL {
|
||||
protoNum, err := r.af.protoNum(proto)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: r.af.protoOffset,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs, srcExprs...)
|
||||
exprs = append(exprs, applyPort(sPort, true)...)
|
||||
exprs = append(exprs, applyPort(dPort, false)...)
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
func (r *family) hasRule(id firewall.RuleID) bool {
|
||||
_, ok := r.filters[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
||||
_, ok := r.rules[id+dnatSuffix]
|
||||
return ok
|
||||
}
|
||||
|
||||
// DeleteFilterRule removes a previously installed filter rule. Source
|
||||
// set references are recovered from the stored rule's expressions via
|
||||
// findSets and dropped from the shared refcounter.
|
||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
||||
ruleID := rule.ID()
|
||||
pr, ok := r.filters[ruleID]
|
||||
if !ok {
|
||||
log.Debugf("filter rule %s not found", ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// A freshly added rule carries no handle until it is read back from
|
||||
// the kernel, and Flush only refreshes the peer chains. Pull live
|
||||
// handles for this rule's chain before deciding it is stale so route
|
||||
// rules (which Flush never refreshes) can actually be deleted.
|
||||
if pr.nftRule.Handle == 0 {
|
||||
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
|
||||
log.Warnf("refresh handles for chain %s: %v", pr.nftRule.Chain.Name, err)
|
||||
}
|
||||
if pr.mangleRule != nil {
|
||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
||||
log.Warnf("refresh mangle handles: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if pr.nftRule.Handle == 0 {
|
||||
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
|
||||
r.dropSourceMatch(pr.nftRule.Exprs)
|
||||
delete(r.filters, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(pr.nftRule); err != nil {
|
||||
log.Errorf("queue rule delete: %v", err)
|
||||
}
|
||||
if pr.mangleRule != nil {
|
||||
if err := r.conn.DelRule(pr.mangleRule); err != nil {
|
||||
log.Errorf("queue mangle rule delete: %v", err)
|
||||
}
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush delete %s: %w", ruleID, err)
|
||||
}
|
||||
|
||||
r.dropSourceMatch(pr.nftRule.Exprs)
|
||||
delete(r.filters, ruleID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
|
||||
if r.ipsetCounter == nil {
|
||||
return nil
|
||||
}
|
||||
sets := findSets(rule)
|
||||
|
||||
var merr *multierror.Error
|
||||
for _, setName := range sets {
|
||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
// findSets scans an nftables rule's expressions for expr.Lookup and
|
||||
// returns the named sets in occurrence order. Used at delete time to
|
||||
// drop ipsetCounter references; peer and route ACLs go through it.
|
||||
func findSets(rule *nftables.Rule) []string {
|
||||
var sets []string
|
||||
for _, e := range rule.Exprs {
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
sets = append(sets, lookup.SetName)
|
||||
}
|
||||
}
|
||||
return sets
|
||||
}
|
||||
|
||||
func (r *family) applyNetwork(
|
||||
network firewall.Network,
|
||||
setPrefixes []netip.Prefix,
|
||||
isSource bool,
|
||||
) ([]expr.Any, error) {
|
||||
if network.IsSet() {
|
||||
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("source: %w", err)
|
||||
}
|
||||
return exprs, nil
|
||||
}
|
||||
|
||||
if network.IsPrefix() {
|
||||
return r.applyPrefix(network.Prefix, isSource), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||
func (r *family) applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
return prefixMatchExprs(r.af, prefix, isSource)
|
||||
}
|
||||
|
||||
// prefixMatchExprs is the family-aware match sequence for a CIDR
|
||||
// prefix. /0 returns nil; a host prefix (full bit length for the
|
||||
// family) skips the bitwise step since the mask is all-ones. Shared
|
||||
// between family and aclManager so both treat single prefixes
|
||||
// identically.
|
||||
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
|
||||
offset := af.dstAddrOffset
|
||||
if isSource {
|
||||
offset = af.srcAddrOffset
|
||||
}
|
||||
|
||||
ones := prefix.Bits()
|
||||
if ones == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
payload := &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: af.addrLen,
|
||||
}
|
||||
cmp := &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: prefix.Masked().Addr().AsSlice(),
|
||||
}
|
||||
|
||||
if ones == af.totalBits {
|
||||
return []expr.Any{payload, cmp}
|
||||
}
|
||||
|
||||
mask := net.CIDRMask(ones, af.totalBits)
|
||||
xor := make([]byte, af.addrLen)
|
||||
return []expr.Any{
|
||||
payload,
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: af.addrLen,
|
||||
Mask: mask,
|
||||
Xor: xor,
|
||||
},
|
||||
cmp,
|
||||
}
|
||||
}
|
||||
|
||||
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
||||
if port == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var exprs []expr.Any
|
||||
|
||||
offset := uint32(2) // Default offset for destination port
|
||||
if isSource {
|
||||
offset = 0 // Offset for source port
|
||||
}
|
||||
|
||||
exprs = append(exprs, &expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: offset,
|
||||
Len: 2,
|
||||
})
|
||||
|
||||
if port.IsRange && len(port.Values) == 2 {
|
||||
// Handle port range
|
||||
exprs = append(exprs,
|
||||
&expr.Range{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
// Handle single port or multiple ports
|
||||
for i, p := range port.Values {
|
||||
if i > 0 {
|
||||
// Add a bitwise OR operation between port checks
|
||||
exprs = append(exprs, &expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0x00, 0x00, 0xff, 0xff},
|
||||
Xor: []byte{0x00, 0x00, 0x00, 0x00},
|
||||
})
|
||||
}
|
||||
exprs = append(exprs, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(p),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return exprs
|
||||
}
|
||||
|
||||
func getCtNewExprs() []expr.Any {
|
||||
return []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeySTATE,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
||||
// single prefix inline, or an ipset for multiple sources.
|
||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
||||
switch {
|
||||
case len(sources) == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||
return firewall.Network{}
|
||||
case len(sources) == 1:
|
||||
return firewall.Network{Prefix: sources[0]}
|
||||
default:
|
||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
||||
}
|
||||
}
|
||||
|
||||
// dropSourceMatch undoes whatever the source/destination match
|
||||
// reserved. Safe to call when the spec is empty or holds only inline
|
||||
// matchers.
|
||||
func (r *family) dropSourceMatch(exprs []expr.Any) {
|
||||
if r.ipsetCounter == nil {
|
||||
return
|
||||
}
|
||||
for _, e := range exprs {
|
||||
lookup, ok := e.(*expr.Lookup)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
|
||||
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ifname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
176
client/firewall/nftables/ipset_linux.go
Normal file
176
client/firewall/nftables/ipset_linux.go
Normal file
@@ -0,0 +1,176 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/expr"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
)
|
||||
|
||||
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||
set: set,
|
||||
prefixes: prefixes,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||
}
|
||||
|
||||
return r.getIpSetExprs(ref, isSource)
|
||||
}
|
||||
|
||||
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||
// overlapping prefixes will result in an error, so we need to merge them
|
||||
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||
|
||||
nfset := &nftables.Set{
|
||||
Name: setName,
|
||||
Comment: input.set.Comment(),
|
||||
Table: r.workTable,
|
||||
// required for prefixes
|
||||
Interval: true,
|
||||
KeyType: r.af.setKeyType,
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
nElements := len(elements)
|
||||
|
||||
maxElements := maxPrefixesSet * 2
|
||||
initialElements := elements[:min(maxElements, nElements)]
|
||||
|
||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
||||
|
||||
var subEnd int
|
||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
||||
subEnd = min(subStart+maxElements, nElements)
|
||||
subElement := elements[subStart:subEnd]
|
||||
nSubPrefixes := len(subElement) / 2
|
||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
||||
return nil, fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, setName, err)
|
||||
}
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf("flush error: %w", err)
|
||||
}
|
||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, setName)
|
||||
}
|
||||
|
||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
||||
return nfset, nil
|
||||
}
|
||||
|
||||
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||
var elements []nftables.SetElement
|
||||
for _, prefix := range prefixes {
|
||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
||||
firstIP := prefix.Addr()
|
||||
lastIP := calculateLastIP(prefix).Next()
|
||||
|
||||
elements = append(elements,
|
||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||
)
|
||||
}
|
||||
return elements
|
||||
}
|
||||
|
||||
// calculateLastIP determines the last IP in a given prefix.
|
||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
||||
masked := prefix.Masked()
|
||||
if masked.Addr().Is4() {
|
||||
hostMask := ^uint32(0) >> masked.Bits()
|
||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
||||
}
|
||||
|
||||
// IPv6: set host bits to all 1s
|
||||
b := masked.Addr().As16()
|
||||
bits := masked.Bits()
|
||||
for i := bits; i < 128; i++ {
|
||||
b[i/8] |= 1 << (7 - i%8)
|
||||
}
|
||||
return netip.AddrFrom16(b)
|
||||
}
|
||||
|
||||
// Utility function to convert netip.Addr to uint32.
|
||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
||||
b := addr.As4()
|
||||
return binary.BigEndian.Uint32(b[:])
|
||||
}
|
||||
|
||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
||||
func uint32ToBytes(ip uint32) [4]byte {
|
||||
var b [4]byte
|
||||
binary.BigEndian.PutUint32(b[:], ip)
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||
r.conn.DelSet(nfset)
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("Deleted unused ipset %s", setName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||
if err != nil {
|
||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
elements := r.convertPrefixesToSet(prefixes)
|
||||
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
return fmt.Errorf(flushError, err)
|
||||
}
|
||||
|
||||
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||
// dst offset by default
|
||||
offset := r.af.dstAddrOffset
|
||||
if isSource {
|
||||
// src offset
|
||||
offset = r.af.srcAddrOffset
|
||||
}
|
||||
|
||||
return []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: offset,
|
||||
Len: r.af.addrLen,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetName: ref.Out.Name,
|
||||
SetID: ref.Out.ID,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
type ipsetStore struct {
|
||||
ipsetReference map[string]int
|
||||
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||
}
|
||||
|
||||
func newIpsetStore() *ipsetStore {
|
||||
return &ipsetStore{
|
||||
ipsetReference: make(map[string]int),
|
||||
ipsets: make(map[string]map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||
r, ok := s.ipsets[ipsetName]
|
||||
return r, ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||
s.ipsetReference[ipsetName] = 0
|
||||
ipList := make(map[string]struct{})
|
||||
s.ipsets[ipsetName] = ipList
|
||||
return ipList
|
||||
}
|
||||
|
||||
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||
delete(s.ipsetReference, ipsetName)
|
||||
delete(s.ipsets, ipsetName)
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(ipList, ip.String())
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ipList[ip.String()] = struct{}{}
|
||||
}
|
||||
|
||||
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||
ipList, ok := s.ipsets[ipsetName]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
_, ok = ipList[ip.String()]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||
s.ipsetReference[ipsetName]++
|
||||
}
|
||||
|
||||
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||
r, ok := s.ipsetReference[ipsetName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if r == 0 {
|
||||
return
|
||||
}
|
||||
s.ipsetReference[ipsetName]--
|
||||
}
|
||||
|
||||
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||
return false
|
||||
}
|
||||
if s.ipsetReference[ipsetName] == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package nftables
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -45,18 +44,17 @@ type iFaceMapper interface {
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// Manager of iptables firewall
|
||||
// Manager of nftables firewall. Per-family state (peer ACLs, route
|
||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
||||
// by family and provides the public firewall.Manager surface.
|
||||
type Manager struct {
|
||||
mutex sync.Mutex
|
||||
rConn *nftables.Conn
|
||||
wgIface iFaceMapper
|
||||
|
||||
router *router
|
||||
aclManager *AclManager
|
||||
|
||||
// IPv6 counterparts, nil when no v6 overlay
|
||||
router6 *router
|
||||
aclManager6 *AclManager
|
||||
family4 *family
|
||||
// IPv6 counterpart, nil when no v6 overlay.
|
||||
family6 *family
|
||||
|
||||
notrackOutputChain *nftables.Chain
|
||||
notrackPreroutingChain *nftables.Chain
|
||||
@@ -75,14 +73,9 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
m.family4, err = newFamily(workTable, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
|
||||
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||
return nil, fmt.Errorf("create family: %w", err)
|
||||
}
|
||||
|
||||
if wgIface.Address().HasIPv6() {
|
||||
@@ -100,26 +93,21 @@ func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mt
|
||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||
|
||||
var err error
|
||||
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||
m.family6, err = newFamily(workTable6, wgIface, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 router: %w", err)
|
||||
return fmt.Errorf("create v6 family: %w", err)
|
||||
}
|
||||
|
||||
// Share the same IP forwarding state with the v4 router, since
|
||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||
m.router6.ipFwdState = m.router.ipFwdState
|
||||
|
||||
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||
}
|
||||
m.family6.ipFwdState = m.family4.ipFwdState
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||
func (m *Manager) hasIPv6() bool {
|
||||
return m.router6 != nil
|
||||
return m.family6 != nil
|
||||
}
|
||||
|
||||
func (m *Manager) initIPv6() error {
|
||||
@@ -128,12 +116,8 @@ func (m *Manager) initIPv6() error {
|
||||
return fmt.Errorf("create v6 work table: %w", err)
|
||||
}
|
||||
|
||||
if err := m.router6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||
if err := m.family6.init(workTable6); err != nil {
|
||||
return fmt.Errorf("v6 family init: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -162,13 +146,13 @@ func (m *Manager) reconcileExternalChains() error {
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var merr *multierror.Error
|
||||
if m.router != nil {
|
||||
if err := m.router.acceptExternalChainsRules(); err != nil {
|
||||
if m.family4 != nil {
|
||||
if err := m.family4.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
||||
}
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
||||
if err := m.family6.acceptExternalChainsRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -187,12 +171,8 @@ func (m *Manager) initFirewall() (err error) {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := m.router.init(workTable); err != nil {
|
||||
return fmt.Errorf("router init: %w", err)
|
||||
}
|
||||
|
||||
if err := m.aclManager.init(workTable); err != nil {
|
||||
return fmt.Errorf("acl manager init: %w", err)
|
||||
if err := m.family4.init(workTable); err != nil {
|
||||
return fmt.Errorf("family init: %w", err)
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
@@ -220,7 +200,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
InterfaceState: &InterfaceState{
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
MTU: m.router.mtu,
|
||||
MTU: m.family4.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -235,12 +215,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
||||
|
||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||
func (m *Manager) rollbackInit() {
|
||||
if err := m.router.Reset(); err != nil {
|
||||
log.Warnf("rollback router: %v", err)
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
log.Warnf("rollback family: %v", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 router: %v", err)
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
log.Warnf("rollback v6 family: %v", err)
|
||||
}
|
||||
}
|
||||
if err := m.cleanupNetbirdTables(); err != nil {
|
||||
@@ -251,118 +231,108 @@ func (m *Manager) rollbackInit() {
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
// AddFilterRule installs a packet-filtering rule.
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if ip.To4() != nil {
|
||||
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// Destination semantics: zero Network → input chain (peer ACL);
|
||||
// set Network → forward chain (route ACL).
|
||||
//
|
||||
// Sources can mix IPv4 and IPv6 prefixes; they're split by family
|
||||
// and dispatched to the per-family backends.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
) ([]firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
isRoute := destination.IsPrefix() || destination.IsSet()
|
||||
if isRoute {
|
||||
fam := m.family4
|
||||
if isIPv6RouteRule(sources, destination) {
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
fam = m.family6
|
||||
}
|
||||
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
rule, err := fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []firewall.Rule{rule}, nil
|
||||
}
|
||||
|
||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||
return m.aclManager6.DeletePeerRule(rule)
|
||||
v4Sources, v6Sources := splitSourcesByFamily(sources)
|
||||
if len(v6Sources) > 0 && !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add peer filtering for v6 sources %v: %w", v6Sources, firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.aclManager.DeletePeerRule(rule)
|
||||
}
|
||||
|
||||
func isIPv6Rule(rule firewall.Rule) bool {
|
||||
r, ok := rule.(*Rule)
|
||||
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
var out []firewall.Rule
|
||||
if len(v4Sources) > 0 {
|
||||
rule, err := m.family4.AddFilterRule(id, v4Sources, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, rule)
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
if len(v6Sources) > 0 {
|
||||
rule, err := m.family6.AddFilterRule(id, v6Sources, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, rule)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
||||
// router; the cached maps are normally authoritative, so the kernel is only
|
||||
// consulted when neither map knows about the rule.
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
// DeleteFilterRule removes a filtering rule. The rule is looked up by
|
||||
// id in each family's filter cache.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
id := rule.ID()
|
||||
r, err := m.routerForRuleID(id, (*router).hasRule)
|
||||
if err != nil {
|
||||
return err
|
||||
if m.family4.hasRule(id) {
|
||||
return m.family4.DeleteFilterRule(rule)
|
||||
}
|
||||
return r.DeleteRouteRule(rule)
|
||||
if m.hasIPv6() && m.family6.hasRule(id) {
|
||||
return m.family6.DeleteFilterRule(rule)
|
||||
}
|
||||
log.Debugf("filter rule %s not found in any family", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// routerForRuleID picks the router holding the rule with the given id, using
|
||||
// familyForRuleID picks the family holding the rule with the given id, using
|
||||
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
||||
// from the kernel once and re-checks before falling back to the v4 router.
|
||||
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
||||
if has(m.router, id) {
|
||||
return m.router, nil
|
||||
// from the kernel once and re-checks before falling back to the v4 family.
|
||||
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool) (*family, error) {
|
||||
if has(m.family4, id) {
|
||||
return m.family4, nil
|
||||
}
|
||||
if m.hasIPv6() && has(m.router6, id) {
|
||||
return m.router6, nil
|
||||
if m.hasIPv6() && has(m.family6, id) {
|
||||
return m.family6, nil
|
||||
}
|
||||
if !m.hasIPv6() {
|
||||
return m.router, nil
|
||||
return m.family4, nil
|
||||
}
|
||||
if err := m.router.refreshRulesMap(); err != nil {
|
||||
if err := m.family4.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
||||
}
|
||||
if err := m.router6.refreshRulesMap(); err != nil {
|
||||
if err := m.family6.refreshRulesMap(); err != nil {
|
||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
||||
}
|
||||
if has(m.router6, id) && !has(m.router, id) {
|
||||
return m.router6, nil
|
||||
if has(m.family6, id) && !has(m.family4, id) {
|
||||
return m.family6, nil
|
||||
}
|
||||
return m.router, nil
|
||||
return m.family4, nil
|
||||
}
|
||||
|
||||
func (m *Manager) IsServerRouteSupported() bool {
|
||||
@@ -381,10 +351,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddNatRule(pair)
|
||||
return m.family6.AddNatRule(pair)
|
||||
}
|
||||
|
||||
if err := m.router.AddNatRule(pair); err != nil {
|
||||
if err := m.family4.AddNatRule(pair); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -396,7 +366,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||
// so the eventual cleanup still works.
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -412,18 +382,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if !m.hasIPv6() {
|
||||
return nil
|
||||
}
|
||||
return m.router6.RemoveNatRule(pair)
|
||||
return m.family6.RemoveNatRule(pair)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() && pair.Dynamic {
|
||||
v6Pair := firewall.ToV6NatPair(pair)
|
||||
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
@@ -445,11 +415,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
if err := m.family4.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||
if err := m.family6.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -466,11 +436,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management
|
||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
||||
return err
|
||||
}
|
||||
if m.hasIPv6() {
|
||||
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -484,13 +454,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||
if err := m.family4.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.router6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||
if err := m.family6.Reset(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -530,14 +500,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
||||
}
|
||||
|
||||
func (m *Manager) EnableRouting() error {
|
||||
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DisableRouting() error {
|
||||
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -551,12 +521,12 @@ func (m *Manager) Flush() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.aclManager.Flush(); err != nil {
|
||||
if err := m.family4.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() {
|
||||
if err := m.aclManager6.Flush(); err != nil {
|
||||
if err := m.family6.Flush(); err != nil {
|
||||
return fmt.Errorf("flush v6 acl: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -577,9 +547,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
||||
if !m.hasIPv6() {
|
||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddDNATRule(rule)
|
||||
return m.family6.AddDNATRule(rule)
|
||||
}
|
||||
return m.router.AddDNATRule(rule)
|
||||
return m.family4.AddDNATRule(rule)
|
||||
}
|
||||
|
||||
// DeleteDNATRule deletes a DNAT rule
|
||||
@@ -587,7 +557,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
||||
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -608,12 +578,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
||||
return fmt.Errorf("update v6 set: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -630,9 +600,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||
@@ -644,9 +614,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||
@@ -658,9 +628,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||
@@ -672,9 +642,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
||||
if !m.hasIPv6() {
|
||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||
}
|
||||
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -903,3 +873,31 @@ func getEstablishedExprs(register uint32) []expr.Any {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// splitSourcesByFamily partitions a mixed-family prefix list into v4
|
||||
// and v6 buckets. v4-mapped v6 prefixes are normalized to v4 so the
|
||||
// match builder reads the correct address length. An empty input maps
|
||||
// to "match any" on v4 only; callers wanting v6 wildcards must include
|
||||
// ::/0 explicitly.
|
||||
func splitSourcesByFamily(sources []netip.Prefix) (v4, v6 []netip.Prefix) {
|
||||
for _, p := range sources {
|
||||
addr := p.Addr()
|
||||
if addr.Is4() || addr.Is4In6() {
|
||||
v4 = append(v4, firewall.UnmapPrefix(p))
|
||||
} else {
|
||||
v6 = append(v6, p)
|
||||
}
|
||||
}
|
||||
return v4, v6
|
||||
}
|
||||
|
||||
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||
// For static routes, the destination prefix determines the family. For dynamic
|
||||
// routes (DomainSet), the sources determine the family since management
|
||||
// duplicates dynamic rules per family.
|
||||
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||
if destination.IsPrefix() {
|
||||
return destination.Prefix.Addr().Is6()
|
||||
}
|
||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
@@ -70,13 +72,13 @@ func TestNftablesManager(t *testing.T) {
|
||||
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
require.Len(t, rules, 2, "expected 2 rules")
|
||||
@@ -148,14 +150,14 @@ func TestNftablesManager(t *testing.T) {
|
||||
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
||||
|
||||
for _, r := range rule {
|
||||
err = manager.DeletePeerRule(r)
|
||||
err = manager.DeleteFilterRule(r)
|
||||
require.NoError(t, err, "failed to delete rule")
|
||||
}
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
// established rule remains
|
||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||
@@ -180,47 +182,39 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
testClient := &nftables.Conn{}
|
||||
|
||||
// Add accept rule first
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add accept rule")
|
||||
|
||||
// Add deny rule second for the same traffic
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err, "failed to add deny rule")
|
||||
|
||||
err = manager.Flush()
|
||||
require.NoError(t, err, "failed to flush")
|
||||
|
||||
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
||||
require.NoError(t, err, "failed to get rules")
|
||||
|
||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||
|
||||
// Find the accept and deny rules and verify deny comes before accept
|
||||
// Single-source rules emit a direct payload+cmp on the source IP
|
||||
// (no set lookup). Match by source-IP + port + verdict instead of
|
||||
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
|
||||
// that this test predates.
|
||||
wantSrc := ip.AsSlice()
|
||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||
for i, rule := range rules {
|
||||
hasAcceptHTTPSet := false
|
||||
hasDenyHTTPSet := false
|
||||
hasPort80 := false
|
||||
var hasSrc, hasPort80 bool
|
||||
var action string
|
||||
|
||||
for _, e := range rule.Exprs {
|
||||
// Check for set lookup
|
||||
if lookup, ok := e.(*expr.Lookup); ok {
|
||||
switch lookup.SetName {
|
||||
case "accept-http":
|
||||
hasAcceptHTTPSet = true
|
||||
case "deny-http":
|
||||
hasDenyHTTPSet = true
|
||||
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
|
||||
if bytes.Equal(cmp.Data, wantSrc) {
|
||||
hasSrc = true
|
||||
}
|
||||
|
||||
}
|
||||
// Check for port 80
|
||||
if cmp, ok := e.(*expr.Cmp); ok {
|
||||
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||
hasPort80 = true
|
||||
}
|
||||
}
|
||||
// Check for verdict
|
||||
if verdict, ok := e.(*expr.Verdict); ok {
|
||||
switch verdict.Kind {
|
||||
case expr.VerdictAccept:
|
||||
@@ -231,11 +225,15 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
|
||||
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
|
||||
if !hasSrc || !hasPort80 {
|
||||
continue
|
||||
}
|
||||
switch action {
|
||||
case "ACCEPT":
|
||||
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
|
||||
acceptRuleIndex = i
|
||||
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
|
||||
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
|
||||
case "DROP":
|
||||
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
|
||||
denyRuleIndex = i
|
||||
}
|
||||
}
|
||||
@@ -279,7 +277,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
|
||||
if i%100 == 0 {
|
||||
@@ -361,10 +359,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("100.96.0.1")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "failed to add peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||
@@ -437,10 +435,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
||||
})
|
||||
|
||||
ip := netip.MustParseAddr("fd00::2")
|
||||
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "add v6 peer filtering rule")
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||
@@ -550,7 +548,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||
}
|
||||
}
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
prefixes,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
@@ -591,7 +589,7 @@ func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
})
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||
|
||||
492
client/firewall/nftables/nat_linux.go
Normal file
492
client/firewall/nftables/nat_linux.go
Normal file
@@ -0,0 +1,492 @@
|
||||
//go:build !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/nftables"
|
||||
"github.com/google/nftables/binaryutil"
|
||||
"github.com/google/nftables/expr"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
if r.legacyManagement {
|
||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.addNatRule(pair); err != nil {
|
||||
return fmt.Errorf("add nat rule: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
r.rollbackRules(pair)
|
||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||
func (r *family) rollbackRules(pair firewall.RouterPair) {
|
||||
keys := []firewall.RuleID{
|
||||
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||
}
|
||||
for _, key := range keys {
|
||||
rule, ok := r.rules[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||
}
|
||||
delete(r.rules, key)
|
||||
}
|
||||
}
|
||||
|
||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
op := expr.CmpOpEq
|
||||
if pair.Inverse {
|
||||
op = expr.CmpOpNeq
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: op,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
}
|
||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||
exprs = append(exprs, getCtNewExprs()...)
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
||||
if pair.Inverse {
|
||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
||||
}
|
||||
|
||||
exprs = append(exprs,
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
SourceRegister: true,
|
||||
Register: 1,
|
||||
},
|
||||
)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure nat rules come first, so the mark can be overwritten.
|
||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameManglePrerouting],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *family) addPostroutingRules() {
|
||||
// First masquerade rule for traffic coming in from WireGuard interface
|
||||
exprs := []expr.Any{
|
||||
// Match on the first fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
||||
},
|
||||
|
||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: ifname("lo"),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs,
|
||||
})
|
||||
|
||||
// Second masquerade rule for traffic going out through WireGuard interface
|
||||
exprs2 := []expr.Any{
|
||||
// Match on the second fwmark
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||
},
|
||||
|
||||
// Match WireGuard interface
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Masq{},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingNat],
|
||||
Exprs: exprs2,
|
||||
})
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
func (r *family) addMSSClampingRules() error {
|
||||
overhead := uint16(ipv4TCPHeaderSize)
|
||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
||||
overhead = ipv6TCPHeaderSize
|
||||
}
|
||||
if r.mtu <= overhead {
|
||||
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
||||
return nil
|
||||
}
|
||||
mss := r.mtu - overhead
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 13,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 1,
|
||||
Mask: []byte{0x02},
|
||||
Xor: []byte{0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Exthdr{
|
||||
DestRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGt,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Exthdr{
|
||||
SourceRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameMangleForward],
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return r.conn.Flush()
|
||||
}
|
||||
|
||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply source: %w", err)
|
||||
}
|
||||
|
||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("apply destination: %w", err)
|
||||
}
|
||||
|
||||
exprs := []expr.Any{
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
}
|
||||
|
||||
exprs = append(exprs, sourceExp...)
|
||||
exprs = append(exprs, destExp...)
|
||||
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
if _, exists := r.rules[ruleKey]; exists {
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Exprs: exprs,
|
||||
UserData: []byte(ruleKey),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLegacyManagement returns the route manager's legacy management mode
|
||||
func (r *family) GetLegacyManagement() bool {
|
||||
return r.legacyManagement
|
||||
}
|
||||
|
||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
||||
r.legacyManagement = isLegacy
|
||||
}
|
||||
|
||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
for k, rule := range r.rules {
|
||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||
} else {
|
||||
delete(r.rules, k)
|
||||
}
|
||||
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeNatPreroutingRules() error {
|
||||
table := &nftables.Table{
|
||||
Name: tableNat,
|
||||
Family: r.af.tableFamily,
|
||||
}
|
||||
chain := &nftables.Chain{
|
||||
Name: chainNameNatPrerouting,
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityNATDest,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
}
|
||||
rules, err := r.conn.GetRules(table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get rules from nat table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
// Delete rules that have our UserData suffix
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
|
||||
continue
|
||||
}
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
if err := r.refreshRulesMap(); err != nil {
|
||||
return fmt.Errorf(refreshRulesMapError, err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if pair.Masquerade {
|
||||
if err := r.removeNatRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||
}
|
||||
|
||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||
// counters will be off until the next successful removal or refresh cycle.
|
||||
if err := r.conn.Flush(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||
|
||||
rule, exists := r.rules[ruleKey]
|
||||
if !exists {
|
||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if rule.Handle == 0 {
|
||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||
}
|
||||
delete(r.rules, ruleKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||
}
|
||||
|
||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||
|
||||
delete(r.rules, ruleKey)
|
||||
|
||||
if err := r.decrementSetCounter(rule); err != nil {
|
||||
return fmt.Errorf("decrement set counter: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
//go:build !android
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
// need fw manager to init both acl mgr and family for all chains to be present
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
|
||||
nftablesTestingClient := &nftables.Conn{}
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "pair should be inserted")
|
||||
|
||||
@@ -90,7 +90,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
}
|
||||
|
||||
// Build CIDR matching expressions
|
||||
testRouter := &router{af: afIPv4}
|
||||
testRouter := &family{af: afIPv4}
|
||||
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||
|
||||
@@ -107,7 +107,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
// Compare expressions up to the mark setting expressions
|
||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||
found = 1
|
||||
@@ -135,9 +135,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
|
||||
// First add the NAT rule using the router's method
|
||||
// First add the NAT rule using the family's method
|
||||
err = rtr.AddNatRule(testCase.InputPair)
|
||||
require.NoError(t, err, "should add NAT rule")
|
||||
|
||||
@@ -147,7 +147,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||
require.NoError(t, err, "should list rules after removal")
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -200,11 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func(r *router) {
|
||||
defer func(r *family) {
|
||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||
}(r)
|
||||
|
||||
@@ -314,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddRouteFiltering failed")
|
||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||
require.NoError(t, err, "AddFilterRule failed")
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
||||
})
|
||||
|
||||
// Check if the rule is in the internal map
|
||||
rule, ok := r.rules[ruleKey.ID()]
|
||||
assert.True(t, ok, "Rule not found in internal map")
|
||||
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
|
||||
require.True(t, ok, "Rule not found in filters map")
|
||||
rule := stored.nftRule
|
||||
|
||||
t.Log("Internal rule expressions:")
|
||||
for i, expr := range rule.Exprs {
|
||||
@@ -339,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
var nftRule *nftables.Rule
|
||||
for _, rule := range rules {
|
||||
if string(rule.UserData) == ruleKey.ID() {
|
||||
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
|
||||
nftRule = rule
|
||||
break
|
||||
}
|
||||
@@ -367,12 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -518,11 +518,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||
require.NoError(t, err, "Failed to create v6 work table")
|
||||
defer deleteWorkTableIPv6()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create family")
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() {
|
||||
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
@@ -861,13 +861,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Add a real rule to the kernel
|
||||
ruleKey, err := r.AddRouteFiltering(
|
||||
ruleKey, err := r.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
@@ -878,11 +878,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||
})
|
||||
|
||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||
staleKey := "stale-rule-that-does-not-exist"
|
||||
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
@@ -902,6 +902,55 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||
}
|
||||
|
||||
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
|
||||
// rule is actually removed from the kernel on delete. The route chain is
|
||||
// not refreshed by Flush, so the stored rule carries a zero handle;
|
||||
// DeleteFilterRule must pull live handles itself before issuing the
|
||||
// delete or the kernel rule leaks. Regression test for that path.
|
||||
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
}
|
||||
|
||||
workTable, err := createWorkTable()
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
ruleKey, err := r.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||
firewall.ProtocolTCP,
|
||||
nil,
|
||||
&firewall.Port{Values: []uint16{80}},
|
||||
firewall.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
countKernelRules := func() int {
|
||||
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
||||
require.NoError(t, err)
|
||||
n := 0
|
||||
for _, rule := range list {
|
||||
if string(rule.UserData) == string(ruleKey.ID()) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
|
||||
|
||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
||||
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
|
||||
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
|
||||
}
|
||||
|
||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
if check() != NFTABLES {
|
||||
t.Skip("nftables not supported on this system")
|
||||
@@ -911,24 +960,27 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
r, err := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, r.init(workTable))
|
||||
defer func() { require.NoError(t, r.Reset()) }()
|
||||
|
||||
// Inject a stale entry with Handle=0
|
||||
staleKey := "stale-route-rule"
|
||||
r.rules[staleKey] = &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
staleKey := id.RuleID("stale-route-rule")
|
||||
r.filters[staleKey] = &Rule{
|
||||
nftRule: &nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameRoutingFw],
|
||||
Handle: 0,
|
||||
UserData: []byte(staleKey),
|
||||
},
|
||||
ruleID: staleKey,
|
||||
}
|
||||
|
||||
// DeleteRouteRule should not return an error for stale handles
|
||||
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||
// DeleteFilterRule should not return an error for stale handles
|
||||
err = r.DeleteFilterRule(staleKey)
|
||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
|
||||
}
|
||||
|
||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
@@ -950,7 +1002,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
Masquerade: true,
|
||||
}
|
||||
|
||||
rtr := manager.router
|
||||
rtr := manager.family4
|
||||
|
||||
// First add succeeds
|
||||
err = rtr.AddNatRule(pair)
|
||||
@@ -979,7 +1031,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||
|
||||
found := 0
|
||||
for _, rule := range rules {
|
||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
||||
found++
|
||||
}
|
||||
}
|
||||
@@ -1010,7 +1062,7 @@ func TestCalculateLastIP(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||
r := &router{af: afIPv6}
|
||||
r := &family{af: afIPv6}
|
||||
prefixes := []netip.Prefix{
|
||||
netip.MustParsePrefix("fd00::/64"),
|
||||
netip.MustParsePrefix("2001:db8::1/128"),
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
"github.com/google/nftables"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// Rule to handle management of rules
|
||||
// Rule wraps an installed filter rule (peer or route). Source set
|
||||
// membership is encoded in the rule's expressions; DeleteFilterRule
|
||||
// recovers the set name via findSets so the refcounter can drop the
|
||||
// right reference. mangleRule is set only for peer rules.
|
||||
type Rule struct {
|
||||
nftRule *nftables.Rule
|
||||
mangleRule *nftables.Rule
|
||||
nftSet *nftables.Set
|
||||
ruleID string
|
||||
ip net.IP
|
||||
// sources is the canonical source list this rule was created for.
|
||||
sources []netip.Prefix
|
||||
id manager.RuleID
|
||||
}
|
||||
|
||||
// GetRuleID returns the rule id
|
||||
func (r *Rule) ID() string {
|
||||
return r.ruleID
|
||||
// ID returns the rule id
|
||||
func (r *Rule) ID() manager.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
23
client/firewall/nftables/testhelpers_linux_test.go
Normal file
23
client/firewall/nftables/testhelpers_linux_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build integration && !android
|
||||
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, _ := netip.AddrFromSlice(ip)
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
@@ -72,8 +71,10 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
// PeerRules is the canonical list-based storage for peer ACL rules.
|
||||
// Match order is significant: drop rules come before accept rules so
|
||||
// callers should consult the slice in order.
|
||||
type PeerRules []*PeerRule
|
||||
|
||||
type RouteRules []*RouteRule
|
||||
|
||||
@@ -86,20 +87,22 @@ func (r RouteRules) Sort() {
|
||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||
return 1
|
||||
}
|
||||
return strings.Compare(a.id, b.id)
|
||||
return strings.Compare(string(a.id), string(b.id))
|
||||
})
|
||||
}
|
||||
|
||||
// Manager userspace firewall manager
|
||||
type Manager struct {
|
||||
outgoingRules map[netip.Addr]RuleSet
|
||||
incomingDenyRules map[netip.Addr]RuleSet
|
||||
incomingRules map[netip.Addr]RuleSet
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
incomingDenyRules PeerRules
|
||||
incomingRules PeerRules
|
||||
incomingDenyIndex peerRuleIndex
|
||||
incomingAcceptIndex peerRuleIndex
|
||||
peerRulesMap map[nbid.RuleID]*PeerRule
|
||||
routeRules RouteRules
|
||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||
decoders sync.Pool
|
||||
wgIface common.IFaceMapper
|
||||
nativeFirewall firewall.Manager
|
||||
|
||||
mutex sync.RWMutex
|
||||
|
||||
@@ -255,9 +258,6 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
},
|
||||
},
|
||||
nativeFirewall: nativeFirewall,
|
||||
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||
incomingDenyRules: make(map[netip.Addr]RuleSet),
|
||||
incomingRules: make(map[netip.Addr]RuleSet),
|
||||
wgIface: iface,
|
||||
localipmanager: newLocalIPManager(),
|
||||
disableServerRoutes: disableServerRoutes,
|
||||
@@ -266,6 +266,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
flowLogger: flowLogger,
|
||||
netstack: netstack.IsEnabled(),
|
||||
localForwarding: enableLocalForwarding,
|
||||
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
|
||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
@@ -488,75 +489,284 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPeerFiltering rule to the firewall
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
func (m *Manager) AddPeerFiltering(
|
||||
// addPeerFiltering installs an input-chain rule that matches packets
|
||||
// by source only. Called from AddFilterRule when the caller doesn't
|
||||
// specify a destination. Mixed-family inputs are split: each family
|
||||
// gets its own rule with a family-correct ipLayer so packet decoding
|
||||
// matches what the matcher expects.
|
||||
func (m *Manager) addPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
sources []netip.Prefix,
|
||||
proto firewall.Protocol,
|
||||
sPort *firewall.Port,
|
||||
dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
_ string,
|
||||
) ([]firewall.Rule, error) {
|
||||
// TODO: fix in upper layers
|
||||
i, ok := netip.AddrFromSlice(ip)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||
}
|
||||
|
||||
i = i.Unmap()
|
||||
r := PeerRule{
|
||||
id: uuid.New().String(),
|
||||
mgmtId: id,
|
||||
ip: i,
|
||||
ipLayer: layers.LayerTypeIPv6,
|
||||
matchByIP: true,
|
||||
drop: action == firewall.ActionDrop,
|
||||
}
|
||||
if i.Is4() {
|
||||
r.ipLayer = layers.LayerTypeIPv4
|
||||
}
|
||||
|
||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||
r.matchByIP = false
|
||||
}
|
||||
|
||||
r.sPort = sPort
|
||||
r.dPort = dPort
|
||||
|
||||
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||
|
||||
m.mutex.Lock()
|
||||
var targetMap map[netip.Addr]RuleSet
|
||||
if r.drop {
|
||||
targetMap = m.incomingDenyRules
|
||||
} else {
|
||||
targetMap = m.incomingRules
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if sourcesMatchAny(sources) {
|
||||
return []firewall.Rule{m.addOnePeerRule(id, sources, layerTypeAll, true, proto, sPort, dPort, action)}, nil
|
||||
}
|
||||
|
||||
if _, ok := targetMap[r.ip]; !ok {
|
||||
targetMap[r.ip] = make(RuleSet)
|
||||
v4, v6 := splitPrefixesByFamily(sources)
|
||||
out := make([]firewall.Rule, 0, 2)
|
||||
if len(v4) > 0 {
|
||||
out = append(out, m.addOnePeerRule(id, v4, layers.LayerTypeIPv4, false, proto, sPort, dPort, action))
|
||||
}
|
||||
targetMap[r.ip][r.id] = r
|
||||
m.mutex.Unlock()
|
||||
return []firewall.Rule{&r}, nil
|
||||
if len(v6) > 0 {
|
||||
out = append(out, m.addOnePeerRule(id, v6, layers.LayerTypeIPv6, false, proto, sPort, dPort, action))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (m *Manager) AddRouteFiltering(
|
||||
// addOnePeerRule builds and registers a single-family peer rule, or
|
||||
// returns the existing rule when one with the same content key is
|
||||
// already installed. The caller must hold m.mutex. The content key is
|
||||
// the shared GenerateRuleID with an empty destination, so peer
|
||||
// rules dedup the same way route rules and the kernel backends do.
|
||||
//
|
||||
// There is no refcount: a content key is installed once and deleted on
|
||||
// the first DeleteFilterRule for that key. The caller must therefore
|
||||
// key its own tracking by the returned rule id so add and delete stay
|
||||
// balanced per content key; the acl manager does this via
|
||||
// peerRulesPairs. The content key is order-independent, so callers
|
||||
// passing the same sources in any order dedup to one rule.
|
||||
func (m *Manager) addOnePeerRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
ipLayer gopacket.LayerType,
|
||||
matchAny bool,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) *PeerRule {
|
||||
ruleKey := nbid.GenerateRuleID(sources, firewall.Network{}, proto, sPort, dPort, action)
|
||||
if existing, ok := m.peerRulesMap[ruleKey]; ok {
|
||||
return existing
|
||||
}
|
||||
|
||||
rule := m.buildPeerRule(ruleKey, id, sources, ipLayer, matchAny, proto, sPort, dPort, action)
|
||||
m.registerPeerRule(rule)
|
||||
return rule
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerRule(
|
||||
ruleKey nbid.RuleID,
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
ipLayer gopacket.LayerType,
|
||||
matchAny bool,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) *PeerRule {
|
||||
r := &PeerRule{
|
||||
id: ruleKey,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
matchAny: matchAny,
|
||||
action: action,
|
||||
srcPort: sPort,
|
||||
dstPort: dPort,
|
||||
}
|
||||
if !matchAny {
|
||||
r.sourceAddrs = make(map[netip.Addr]struct{}, len(sources))
|
||||
for _, p := range sources {
|
||||
if p.Bits() == p.Addr().BitLen() {
|
||||
r.sourceAddrs[p.Addr()] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
r.protoLayer = protoToLayer(proto, ipLayer)
|
||||
return r
|
||||
}
|
||||
|
||||
// registerPeerRule records a freshly built peer rule in the matching
|
||||
// slice, index, and dedup map. The caller must hold m.mutex.
|
||||
func (m *Manager) registerPeerRule(r *PeerRule) {
|
||||
if r.action == firewall.ActionDrop {
|
||||
m.incomingDenyRules = append(m.incomingDenyRules, r)
|
||||
m.incomingDenyIndex.add(r)
|
||||
} else {
|
||||
m.incomingRules = append(m.incomingRules, r)
|
||||
m.incomingAcceptIndex.add(r)
|
||||
}
|
||||
m.peerRulesMap[r.id] = r
|
||||
}
|
||||
|
||||
// splitPrefixesByFamily partitions a mixed-family prefix list into v4
|
||||
// and v6 buckets. v4-mapped v6 addresses are normalized to v4.
|
||||
func splitPrefixesByFamily(sources []netip.Prefix) (v4, v6 []netip.Prefix) {
|
||||
for _, p := range sources {
|
||||
addr := p.Addr()
|
||||
if addr.Is4() || addr.Is4In6() {
|
||||
v4 = append(v4, firewall.UnmapPrefix(p))
|
||||
} else {
|
||||
v6 = append(v6, p)
|
||||
}
|
||||
}
|
||||
return v4, v6
|
||||
}
|
||||
|
||||
// peerRuleIndex is the source-side dispatcher consulted on the packet
|
||||
// hot path. It separates rules into three disjoint buckets based on
|
||||
// the shape of their source list. Every rule lives in exactly one:
|
||||
//
|
||||
// - bySource: every source is a host prefix (full-bit-length for
|
||||
// its family: /32 for v4, /128 for v6). The map key is the
|
||||
// concrete source address, so a hit guarantees the rule's source
|
||||
// filter passes; the matcher proceeds straight to proto/port
|
||||
// checks without re-verifying the source.
|
||||
// - byCIDR: any source list containing a prefix coarser than a
|
||||
// single host. The matcher walks this slice and runs prefix
|
||||
// Contains() per rule. Expected to be empty for typical peer
|
||||
// ACLs (always host prefixes) and small even when populated.
|
||||
// - matchAny: at least one /0 source. Always matches every packet;
|
||||
// no per-rule source check needed.
|
||||
//
|
||||
// Maintained incrementally by Add/DeleteFilterRule, never rebuilt.
|
||||
type peerRuleIndex struct {
|
||||
bySource map[netip.Addr][]*PeerRule
|
||||
byCIDR []*PeerRule
|
||||
matchAny []*PeerRule
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) add(r *PeerRule) {
|
||||
switch {
|
||||
case r.matchAny:
|
||||
i.matchAny = append(i.matchAny, r)
|
||||
case hasNonHostSource(r):
|
||||
i.byCIDR = append(i.byCIDR, r)
|
||||
default:
|
||||
if i.bySource == nil {
|
||||
i.bySource = make(map[netip.Addr][]*PeerRule)
|
||||
}
|
||||
for a := range r.sourceAddrs {
|
||||
i.bySource[a] = append(i.bySource[a], r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) remove(r *PeerRule) {
|
||||
switch {
|
||||
case r.matchAny:
|
||||
i.matchAny = slices.DeleteFunc(i.matchAny, eqRule(r))
|
||||
case hasNonHostSource(r):
|
||||
i.byCIDR = slices.DeleteFunc(i.byCIDR, eqRule(r))
|
||||
default:
|
||||
if i.bySource == nil {
|
||||
return
|
||||
}
|
||||
for a := range r.sourceAddrs {
|
||||
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
|
||||
if len(entries) == 0 {
|
||||
delete(i.bySource, a)
|
||||
} else {
|
||||
i.bySource[a] = entries
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *peerRuleIndex) reset() {
|
||||
i.bySource = nil
|
||||
i.byCIDR = i.byCIDR[:0]
|
||||
i.matchAny = i.matchAny[:0]
|
||||
}
|
||||
|
||||
func eqRule(target *PeerRule) func(*PeerRule) bool {
|
||||
return func(p *PeerRule) bool { return p == target }
|
||||
}
|
||||
|
||||
// hasNonHostSource reports whether the rule has any source prefix
|
||||
// that is not a single host address. Called only at add/remove time,
|
||||
// not on the packet path.
|
||||
func hasNonHostSource(r *PeerRule) bool {
|
||||
for _, p := range r.sources {
|
||||
if p.Bits() != p.Addr().BitLen() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sourcesMatchAny reports whether the source list matches every source,
|
||||
// i.e. contains an explicit /0 prefix. An empty list does not qualify:
|
||||
// AddFilterRule rejects it with ErrNoSources, so "match any" is always
|
||||
// the deliberate /0 case.
|
||||
func sourcesMatchAny(sources []netip.Prefix) bool {
|
||||
for _, p := range sources {
|
||||
if p.Bits() == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AddFilterRule is the unified entry point for both peer (input chain)
|
||||
// and route (forward chain) filtering rules. The destination
|
||||
// distinguishes the two semantics: a zero Network installs an
|
||||
// input-side rule that matches by source only; a set Network installs
|
||||
// a forward-side rule that also matches the destination.
|
||||
func (m *Manager) AddFilterRule(
|
||||
id []byte,
|
||||
sources []netip.Prefix,
|
||||
destination firewall.Network,
|
||||
proto firewall.Protocol,
|
||||
sPort, dPort *firewall.Port,
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
) ([]firewall.Rule, error) {
|
||||
if len(sources) == 0 {
|
||||
return nil, firewall.ErrNoSources
|
||||
}
|
||||
if destination.IsPrefix() || destination.IsSet() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
r, err := m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []firewall.Rule{r}, nil
|
||||
}
|
||||
// Peer path: sources are expected to be single-family; the acl
|
||||
// manager keys its selector grouping on family, and management
|
||||
// emits one FirewallRule per family upstream of that. The kernel
|
||||
// backends still split defensively because their per-family
|
||||
// tables can't encode the other family's addresses; uspfilter
|
||||
// has no such constraint.
|
||||
return m.addPeerFiltering(id, sources, proto, sPort, dPort, action)
|
||||
}
|
||||
|
||||
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
|
||||
// is used to route to the correct internal path.
|
||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
||||
switch r := rule.(type) {
|
||||
case *RouteRule:
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.deleteRouteRule(rule)
|
||||
case *PeerRule:
|
||||
return m.deletePeerRule(rule)
|
||||
case firewall.RuleID:
|
||||
// Deletion by bare id. Resolve to the concrete rule rather than
|
||||
// assuming a route: a peer rule if we own that id, otherwise the
|
||||
// route path. Without this, deleting a peer rule by id would
|
||||
// silently miss in the route map.
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
if pr, ok := m.peerRulesMap[r]; ok {
|
||||
return m.deletePeerRuleLocked(pr)
|
||||
}
|
||||
return m.deleteRouteRule(rule)
|
||||
default:
|
||||
// Native firewall route rules implement firewall.Rule but
|
||||
// aren't one of our concrete types; the route path knows how
|
||||
// to forward them.
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) addRouteFiltering(
|
||||
@@ -568,18 +778,24 @@ func (m *Manager) addRouteFiltering(
|
||||
action firewall.Action,
|
||||
) (firewall.Rule, error) {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||
rules, err := m.nativeFirewall.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(rules) == 0 {
|
||||
return nil, fmt.Errorf("native firewall returned no rule")
|
||||
}
|
||||
return rules[0], nil
|
||||
}
|
||||
|
||||
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||
ruleKey := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
||||
|
||||
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||
return existingRule, nil
|
||||
}
|
||||
|
||||
rule := RouteRule{
|
||||
// TODO: consolidate these IDs
|
||||
id: string(ruleKey),
|
||||
id: ruleKey,
|
||||
mgmtId: id,
|
||||
sources: sources,
|
||||
dstSet: destination.Set,
|
||||
@@ -599,25 +815,18 @@ func (m *Manager) addRouteFiltering(
|
||||
return &rule, nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
return m.deleteRouteRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||
return m.nativeFirewall.DeleteFilterRule(rule)
|
||||
}
|
||||
|
||||
ruleKey := nbid.RuleID(rule.ID())
|
||||
ruleKey := rule.ID()
|
||||
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||
}
|
||||
|
||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||
return r.id == string(ruleKey)
|
||||
return r.id == ruleKey
|
||||
})
|
||||
if idx < 0 {
|
||||
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||
@@ -628,8 +837,8 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePeerRule from the firewall by rule definition
|
||||
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// deletePeerRule removes an input-chain rule. Acquires m.mutex.
|
||||
func (m *Manager) deletePeerRule(rule firewall.Rule) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
@@ -637,26 +846,31 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||
if !ok {
|
||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||
}
|
||||
return m.deletePeerRuleLocked(r)
|
||||
}
|
||||
|
||||
var sourceMap map[netip.Addr]RuleSet
|
||||
if r.drop {
|
||||
sourceMap = m.incomingDenyRules
|
||||
// deletePeerRuleLocked removes a peer rule from the matching slice,
|
||||
// index, and dedup map. The caller must hold m.mutex.
|
||||
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
|
||||
var target *PeerRules
|
||||
if r.action == firewall.ActionDrop {
|
||||
target = &m.incomingDenyRules
|
||||
} else {
|
||||
sourceMap = m.incomingRules
|
||||
target = &m.incomingRules
|
||||
}
|
||||
|
||||
if ruleset, ok := sourceMap[r.ip]; ok {
|
||||
if _, exists := ruleset[r.id]; !exists {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
delete(ruleset, r.id)
|
||||
if len(ruleset) == 0 {
|
||||
delete(sourceMap, r.ip)
|
||||
}
|
||||
} else {
|
||||
pos := slices.IndexFunc(*target, func(p *PeerRule) bool { return p.id == r.id })
|
||||
if pos < 0 {
|
||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||
}
|
||||
|
||||
stored := (*target)[pos]
|
||||
*target = slices.Delete(*target, pos, pos+1)
|
||||
if r.action == firewall.ActionDrop {
|
||||
m.incomingDenyIndex.remove(stored)
|
||||
} else {
|
||||
m.incomingAcceptIndex.remove(stored)
|
||||
}
|
||||
delete(m.peerRulesMap, r.id)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -674,9 +888,11 @@ func (m *Manager) Flush() error { return nil }
|
||||
// resetState clears all firewall rules and closes connection trackers.
|
||||
// Must be called with m.mutex held.
|
||||
func (m *Manager) resetState() {
|
||||
clear(m.outgoingRules)
|
||||
clear(m.incomingDenyRules)
|
||||
clear(m.incomingRules)
|
||||
m.incomingDenyRules = m.incomingDenyRules[:0]
|
||||
m.incomingRules = m.incomingRules[:0]
|
||||
m.incomingDenyIndex.reset()
|
||||
m.incomingAcceptIndex.reset()
|
||||
clear(m.peerRulesMap)
|
||||
clear(m.routeRulesMap)
|
||||
m.routeRules = m.routeRules[:0]
|
||||
m.udpHookOut.Store(nil)
|
||||
@@ -820,11 +1036,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||
case layers.LayerTypeIPv4:
|
||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||
return src, dst
|
||||
return src.Unmap(), dst.Unmap()
|
||||
case layers.LayerTypeIPv6:
|
||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||
return src, dst
|
||||
return src.Unmap(), dst.Unmap()
|
||||
default:
|
||||
return netip.Addr{}, netip.Addr{}
|
||||
}
|
||||
@@ -1404,23 +1620,83 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
|
||||
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
|
||||
return mgmtId, filter
|
||||
}
|
||||
|
||||
return nil, true
|
||||
}
|
||||
|
||||
// match walks the three buckets in order and returns the first rule
|
||||
// that matches src and the decoded packet. Source filtering is
|
||||
// dispatched up-front by bucket: bySource[src] is by definition
|
||||
// source-matching; byCIDR rules need a prefix Contains() check;
|
||||
// matchAny rules apply to every source. Within each bucket the
|
||||
// matcher runs proto/port filters.
|
||||
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
if len(i.bySource) > 0 {
|
||||
if rules := i.bySource[src]; len(rules) > 0 {
|
||||
for _, rule := range rules {
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, rule := range i.byCIDR {
|
||||
if !prefixesContain(rule.sources, src) {
|
||||
continue
|
||||
}
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
for _, rule := range i.matchAny {
|
||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
||||
return id, drop, true
|
||||
}
|
||||
}
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// matchProto applies the proto/port half of a rule against the
|
||||
// decoded packet. Source matching is the caller's responsibility.
|
||||
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
|
||||
drop := rule.action == firewall.ActionDrop
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
return nil, false, false
|
||||
}
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.mgmtId, drop, true
|
||||
}
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
|
||||
for _, p := range sources {
|
||||
if p.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
if rulePort == nil {
|
||||
return true
|
||||
@@ -1438,39 +1714,6 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||
payloadLayer := d.decoded[1]
|
||||
|
||||
for _, rule := range rules {
|
||||
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.protoLayer == layerTypeAll {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
|
||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch payloadLayer {
|
||||
case layers.LayerTypeTCP:
|
||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeUDP:
|
||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||
return rule.mgmtId, rule.drop, true
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false, false
|
||||
}
|
||||
|
||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||
m.mutex.RLock()
|
||||
|
||||
@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: false,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Single rule allowing all traffic
|
||||
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||
@@ -114,15 +114,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Add explicit rules matching return traffic pattern
|
||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||
ip := generateRandomIPs(1)[0]
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFilterRule(
|
||||
nil,
|
||||
ip,
|
||||
pfx(ip), fw.Network{},
|
||||
fw.ProtocolTCP,
|
||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||
&fw.Port{Values: []uint16{80}},
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
},
|
||||
@@ -133,15 +131,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
stateful: true,
|
||||
setupFunc: func(m *Manager) {
|
||||
// Add some basic rules but rely on state for established connections
|
||||
_, err := m.AddPeerFiltering(
|
||||
_, err := m.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP("0.0.0.0"),
|
||||
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
|
||||
fw.ProtocolTCP,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionDrop,
|
||||
"",
|
||||
)
|
||||
fw.ActionDrop)
|
||||
require.NoError(b, err)
|
||||
},
|
||||
desc: "Connection tracking with established connections",
|
||||
@@ -546,7 +542,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -629,7 +625,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
// Single rule to allow all return traffic from port 80
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -739,7 +735,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
// Setup initial state based on scenario
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -818,7 +814,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
})
|
||||
|
||||
if sc.rules {
|
||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
@@ -931,7 +927,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
|
||||
for _, r := range rules {
|
||||
dst := fw.Network{Prefix: r.dest}
|
||||
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -496,39 +496,35 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
// add general accept rule for the same IP to test drop rule precedence
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
rules, err := manager.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||
fw.ProtocolALL,
|
||||
nil,
|
||||
nil,
|
||||
fw.ActionAccept,
|
||||
"",
|
||||
)
|
||||
fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(
|
||||
rules, err := manager.AddFilterRule(
|
||||
nil,
|
||||
net.ParseIP(tc.ruleIP),
|
||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
||||
tc.ruleProto,
|
||||
tc.ruleSrcPort,
|
||||
tc.ruleDstPort,
|
||||
tc.ruleAction,
|
||||
"",
|
||||
)
|
||||
tc.ruleAction)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -672,21 +668,21 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.ruleAction == fw.ActionDrop {
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, rules)
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeletePeerRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1405,7 +1401,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if tc.rule.action == fw.ActionDrop {
|
||||
// add general accept rule to test drop rule
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||
@@ -1415,13 +1411,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
fw.ActionAccept,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule[0]))
|
||||
})
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
tc.rule.sources,
|
||||
tc.rule.dest,
|
||||
@@ -1431,10 +1427,10 @@ func TestRouteACLFiltering(t *testing.T) {
|
||||
tc.rule.action,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule[0]))
|
||||
})
|
||||
|
||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||
@@ -1604,7 +1600,7 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var rules []fw.Rule
|
||||
for _, r := range tc.rules {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
r.sources,
|
||||
r.dest,
|
||||
@@ -1614,13 +1610,13 @@ func TestRouteACLOrder(t *testing.T) {
|
||||
r.action,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, rule)
|
||||
rules = append(rules, rule)
|
||||
require.NotEmpty(t, rule)
|
||||
rules = append(rules, rule...)
|
||||
}
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, rule := range rules {
|
||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1655,7 +1651,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
// Add rule that uses the set (initially empty)
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -1689,7 +1685,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||
|
||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||
_, err := manager.AddRouteFiltering(
|
||||
_, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: v6Dst},
|
||||
@@ -1700,7 +1696,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = manager.AddRouteFiltering(
|
||||
_, err = manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||
|
||||
@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add rule first time
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
require.NotNil(t, rule1)
|
||||
|
||||
// Add the same rule again
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -55,7 +55,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||
require.NotNil(t, rule2)
|
||||
|
||||
// These should be the same (idempotent) like nftables/iptables implementations
|
||||
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||
assert.Equal(t, rule1[0].ID(), rule2[0].ID(),
|
||||
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||
|
||||
manager.mutex.RLock()
|
||||
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
|
||||
// Add first rule
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add different rule (different destination)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-2"),
|
||||
sources,
|
||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||
@@ -97,7 +97,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||
assert.NotEqual(t, rule1[0].ID(), rule2[0].ID(),
|
||||
"Different rules should have different IDs")
|
||||
|
||||
manager.mutex.RLock()
|
||||
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
require.True(t, pass, "Traffic should pass with rule in place")
|
||||
|
||||
// Re-add same rule (simulates network map update)
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -143,11 +143,11 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||
// Idempotent IDs mean rule1[0].ID() == rule2[0].ID(), so the ACL manager
|
||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||
// would remove the only matching rule and cause a traffic gap.
|
||||
if rule1.ID() != rule2.ID() {
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
if rule1[0].ID() != rule2[0].ID() {
|
||||
err = manager.DeleteFilterRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -274,7 +274,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||
|
||||
// Simulate 5 network map updates with the same route rule
|
||||
for i := 0; i < 5; i++ {
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -304,7 +304,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||
|
||||
// Add same rule twice
|
||||
rule1, err := manager.AddRouteFiltering(
|
||||
rule1, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -315,7 +315,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := manager.AddRouteFiltering(
|
||||
rule2, err := manager.AddFilterRule(
|
||||
[]byte("policy-1"),
|
||||
sources,
|
||||
destination,
|
||||
@@ -326,10 +326,10 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||
require.Equal(t, rule1[0].ID(), rule2[0].ID(), "Should return same rule ID")
|
||||
|
||||
// Delete using first reference
|
||||
err = manager.DeleteRouteRule(rule1)
|
||||
err = manager.DeleteFilterRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify traffic no longer passes
|
||||
|
||||
@@ -89,7 +89,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
func TestManagerAddFilterRule(t *testing.T) {
|
||||
isSetFilterCalled := false
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error {
|
||||
@@ -109,7 +109,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -142,7 +142,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -155,42 +155,39 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
t.Errorf("rule should be a PeerRule")
|
||||
continue
|
||||
}
|
||||
// Check if rule exists in deny or allow maps based on action
|
||||
var found bool
|
||||
if peerRule.drop {
|
||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||
if peerRule.action == fw.ActionDrop {
|
||||
found = findRuleByID(m.incomingDenyRules, ip, r.ID())
|
||||
} else {
|
||||
_, found = m.incomingRules[ip][r.ID()]
|
||||
found = findRuleByID(m.incomingRules, ip, r.ID())
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("rule2 is not in the expected rules map")
|
||||
t.Errorf("rule2 is not in the expected rules list")
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range rule2 {
|
||||
err = m.DeletePeerRule(r)
|
||||
err = m.DeleteFilterRule(r)
|
||||
if err != nil {
|
||||
t.Errorf("failed to delete rule: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Check rules are removed from appropriate maps
|
||||
for _, r := range rule2 {
|
||||
peerRule, ok := r.(*PeerRule)
|
||||
if !ok {
|
||||
t.Errorf("rule should be a PeerRule")
|
||||
continue
|
||||
}
|
||||
// Check if rule is removed from deny or allow maps based on action
|
||||
var found bool
|
||||
if peerRule.drop {
|
||||
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||
if peerRule.action == fw.ActionDrop {
|
||||
found = findRuleByID(m.incomingDenyRules, ip, r.ID())
|
||||
} else {
|
||||
_, found = m.incomingRules[ip][r.ID()]
|
||||
found = findRuleByID(m.incomingRules, ip, r.ID())
|
||||
}
|
||||
if found {
|
||||
t.Errorf("rule2 should be removed from the rules map")
|
||||
t.Errorf("rule2 should be removed from the rules list")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -260,36 +257,34 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
|
||||
// Add multiple deny rules for different ports
|
||||
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||
|
||||
// Delete the first deny rule
|
||||
err = m.DeletePeerRule(rule1[0])
|
||||
err = m.DeleteFilterRule(rule1[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount = len(m.incomingDenyRules[addr])
|
||||
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||
|
||||
// Delete the second deny rule
|
||||
err = m.DeletePeerRule(rule2[0])
|
||||
err = m.DeleteFilterRule(rule2[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, exists := m.incomingDenyRules[addr]
|
||||
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||
m.mutex.RUnlock()
|
||||
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||
require.False(t, exists, "Deny rules should be cleaned up when empty")
|
||||
}
|
||||
|
||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||
@@ -311,27 +306,25 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||
// Simulate 10 network map updates: add rule, delete old, add new
|
||||
for i := 0; i < 10; i++ {
|
||||
// Add a deny rule
|
||||
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add an allow rule
|
||||
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete them (simulating ACL manager cleanup)
|
||||
for _, r := range rules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
require.NoError(t, m.DeleteFilterRule(r))
|
||||
}
|
||||
for _, r := range allowRules {
|
||||
require.NoError(t, m.DeletePeerRule(r))
|
||||
require.NoError(t, m.DeleteFilterRule(r))
|
||||
}
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
allowCount := countRulesForAddr(m.incomingRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||
@@ -354,41 +347,39 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
|
||||
// Add allow rule for port 80
|
||||
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add deny rule for port 22
|
||||
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
|
||||
addr := netip.MustParseAddr("192.168.1.1")
|
||||
m.mutex.RLock()
|
||||
allowCount := len(m.incomingRules[addr])
|
||||
denyCount := len(m.incomingDenyRules[addr])
|
||||
allowCount := countRulesForAddr(m.incomingRules, addr)
|
||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||
|
||||
// Delete allow rule should not affect deny rule
|
||||
err = m.DeletePeerRule(allowRule[0])
|
||||
err = m.DeleteFilterRule(allowRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||
|
||||
// Delete deny rule
|
||||
err = m.DeletePeerRule(denyRule[0])
|
||||
err = m.DeleteFilterRule(denyRule[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
m.mutex.RLock()
|
||||
_, denyExists := m.incomingDenyRules[addr]
|
||||
_, allowExists := m.incomingRules[addr]
|
||||
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
||||
allowExists := countRulesForAddr(m.incomingRules, addr) > 0
|
||||
m.mutex.RUnlock()
|
||||
|
||||
require.False(t, denyExists, "Deny rules should be empty")
|
||||
@@ -411,7 +402,7 @@ func TestManagerReset(t *testing.T) {
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -423,7 +414,7 @@ func TestManagerReset(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||
if len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||
t.Errorf("rules are not empty")
|
||||
}
|
||||
}
|
||||
@@ -449,7 +440,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
proto := fw.ProtocolUDP
|
||||
action := fw.ActionAccept
|
||||
|
||||
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
if err != nil {
|
||||
t.Errorf("failed to add filtering: %v", err)
|
||||
return
|
||||
@@ -621,7 +612,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
start := time.Now()
|
||||
for i := 0; i < testMax; i++ {
|
||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
||||
|
||||
require.NoError(t, err, "failed to add rule")
|
||||
}
|
||||
@@ -858,7 +849,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
}
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -916,7 +907,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
if r.id == rule[0].ID() {
|
||||
foundRule = true
|
||||
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||
"Rule should have all prefixes merged")
|
||||
@@ -939,7 +930,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
|
||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||
|
||||
rule, err := manager.AddRouteFiltering(
|
||||
rule, err := manager.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||
fw.Network{Set: set},
|
||||
@@ -965,7 +956,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
manager.mutex.RLock()
|
||||
foundRule := false
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
if r.id == rule[0].ID() {
|
||||
foundRule = true
|
||||
// Should have deduplicated to 2 prefixes
|
||||
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||
@@ -998,7 +989,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||
manager.mutex.RLock()
|
||||
for _, r := range manager.routeRules {
|
||||
if r.id == rule.ID() {
|
||||
if r.id == rule[0].ID() {
|
||||
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||
require.Len(t, r.destinations, 4,
|
||||
"Overlapping prefixes should not be deduplicated")
|
||||
|
||||
327
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
327
client/firewall/uspfilter/peer_acl_bench_test.go
Normal file
@@ -0,0 +1,327 @@
|
||||
//go:build uspbench
|
||||
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
|
||||
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
|
||||
// rules, each with K source peers in its set.
|
||||
//
|
||||
// With the reverse-source index, miss cost is independent of M and
|
||||
// hit cost grows only with the number of rules touching a single
|
||||
// srcIP, not with total rule count.
|
||||
func BenchmarkPeerACLMatch(b *testing.B) {
|
||||
shapes := []struct{ M, K int }{
|
||||
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
|
||||
}
|
||||
families := []struct {
|
||||
name string
|
||||
v6 bool
|
||||
}{{"v4", false}, {"v6", true}}
|
||||
|
||||
for _, fam := range families {
|
||||
for _, s := range shapes {
|
||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
|
||||
runPeerACLBench(b, s.M, s.K, true, fam.v6)
|
||||
})
|
||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
|
||||
runPeerACLBench(b, s.M, s.K, false, fam.v6)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
|
||||
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
|
||||
|
||||
// Miss packets are dropped, so they always traverse the full peer
|
||||
// ACL matcher (every bucket) without short-circuiting and without
|
||||
// touching conntrack. Disable conntrack for the miss case so it
|
||||
// measures the matcher, not established-state lookups. The hit case
|
||||
// keeps conntrack on: an accepted packet reaches trackInbound, which
|
||||
// needs the trackers conntrack creates.
|
||||
if !hit {
|
||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||
}
|
||||
|
||||
bits := 32
|
||||
genPkt := generatePacket
|
||||
addrs := uniqueAddrs
|
||||
if v6 {
|
||||
bits = 128
|
||||
genPkt = generatePacket6
|
||||
addrs = uniqueAddrs6
|
||||
}
|
||||
|
||||
// dstIP must be a local IP so filterInbound takes the local-traffic
|
||||
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
|
||||
// address the manager doesn't own would be treated as routed and
|
||||
// short-circuit before the peer matcher.
|
||||
dstIP := addrs(1, 2)[0]
|
||||
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
|
||||
if v6 {
|
||||
// The local-IP manager needs a valid v4 address too; expose the v6
|
||||
// dst as the interface's IPv6 so IsLocalIP recognizes it.
|
||||
mockAddr = wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.64.0.1"),
|
||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
||||
IPv6: dstIP,
|
||||
IPv6Net: netip.PrefixFrom(dstIP, bits),
|
||||
}
|
||||
}
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address { return mockAddr },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
|
||||
|
||||
// Generate M policies × K source peers, all distinct.
|
||||
all := addrs(m*k, 1)
|
||||
for i := 0; i < m; i++ {
|
||||
sources := make([]netip.Prefix, k)
|
||||
for j, a := range all[i*k : (i+1)*k] {
|
||||
sources[j] = netip.PrefixFrom(a, bits)
|
||||
}
|
||||
_, err := manager.AddFilterRule(
|
||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
|
||||
// Hit: cycle through real sources, picking the matching policy's port.
|
||||
// Miss: a source from a disjoint range, port 80 (matches no policy).
|
||||
var pktFn func(i int) []byte
|
||||
if hit {
|
||||
pktFn = func(i int) []byte {
|
||||
policy := i % m
|
||||
src := all[policy*k+(i%k)]
|
||||
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
|
||||
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
|
||||
}
|
||||
} else {
|
||||
miss := addrs(4096, 99)
|
||||
pktFn = func(i int) []byte {
|
||||
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
|
||||
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-build a pool to avoid allocations dominating the measurement.
|
||||
pool := make([][]byte, 1024)
|
||||
for i := range pool {
|
||||
pool[i] = pktFn(i)
|
||||
}
|
||||
|
||||
// Confirm the matcher is actually exercised: a hit packet must be
|
||||
// allowed and a miss packet dropped. Without this the benchmark
|
||||
// could silently time the routed early-return instead.
|
||||
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
|
||||
"benchmark must reach the peer ACL matcher")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterInbound(pool[i%len(pool)], 0)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
|
||||
// the source-keyed index across realistic deployment shapes. Two
|
||||
// dimensions matter: (M, K), the number of policies × peers-per-policy,
|
||||
// and overlap, the fraction of peers shared between policies.
|
||||
//
|
||||
// The output uses ReportMetric("bytes/rule") so the cost can be
|
||||
// compared across shapes directly. Total bytes = bytes/rule * M.
|
||||
func BenchmarkPeerACLIndexMemory(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
M, K int
|
||||
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
|
||||
}{
|
||||
{"M=10/K=100/disjoint", 10, 100, 0},
|
||||
{"M=100/K=100/disjoint", 100, 100, 0},
|
||||
{"M=100/K=1000/disjoint", 100, 1000, 0},
|
||||
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
|
||||
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
|
||||
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
b.Run(c.name, func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
mgr, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
|
||||
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
|
||||
|
||||
runtime.GC()
|
||||
var ms runtime.MemStats
|
||||
runtime.ReadMemStats(&ms)
|
||||
before := ms.HeapAlloc
|
||||
|
||||
// Drop the manager's external roots so we can isolate
|
||||
// the index cost. We hold the manager itself live; the
|
||||
// index is what we measure on the second pass.
|
||||
mgr.incomingAcceptIndex.reset()
|
||||
mgr.incomingDenyIndex.reset()
|
||||
mgr.incomingRules = mgr.incomingRules[:0]
|
||||
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&ms)
|
||||
after := ms.HeapAlloc
|
||||
|
||||
delta := int64(before) - int64(after)
|
||||
if delta < 0 {
|
||||
delta = 0
|
||||
}
|
||||
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
|
||||
b.ReportMetric(float64(delta), "bytes/total")
|
||||
|
||||
require.NoError(b, mgr.Close(nil))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
|
||||
b.Helper()
|
||||
pool := uniqueAddrs(k+m*k, 1) // big enough universe
|
||||
sharedLen := int(float64(k) * overlapFrac)
|
||||
if sharedLen > k {
|
||||
sharedLen = k
|
||||
}
|
||||
shared := pool[:sharedLen]
|
||||
uniquePool := pool[sharedLen:]
|
||||
|
||||
for i := 0; i < m; i++ {
|
||||
sources := make([]netip.Prefix, 0, k)
|
||||
for _, a := range shared {
|
||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
|
||||
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
|
||||
for _, a := range unique {
|
||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
||||
}
|
||||
_, err := mgr.AddFilterRule(
|
||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
||||
fw.ActionAccept)
|
||||
require.NoError(b, err)
|
||||
}
|
||||
}
|
||||
|
||||
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
|
||||
// policy sources / dst; seed 99 puts misses in 10/8.
|
||||
func uniqueAddrs(n int, seed int64) []netip.Addr {
|
||||
out := make([]netip.Addr, 0, n)
|
||||
seen := make(map[netip.Addr]struct{}, n)
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
miss := seed == 99
|
||||
for len(out) < n {
|
||||
var b [4]byte
|
||||
if miss {
|
||||
b[0] = 10
|
||||
b[1] = byte(r.Intn(256))
|
||||
} else {
|
||||
b[0] = 100
|
||||
b[1] = byte(64 + r.Intn(63))
|
||||
}
|
||||
b[2] = byte(r.Intn(256))
|
||||
b[3] = byte(1 + r.Intn(254))
|
||||
a := netip.AddrFrom4(b)
|
||||
if _, ok := seen[a]; ok {
|
||||
continue
|
||||
}
|
||||
seen[a] = struct{}{}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
|
||||
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
|
||||
// disjoint from any source.
|
||||
func uniqueAddrs6(n int, seed int64) []netip.Addr {
|
||||
out := make([]netip.Addr, 0, n)
|
||||
seen := make(map[netip.Addr]struct{}, n)
|
||||
r := rand.New(rand.NewSource(seed))
|
||||
miss := seed == 99
|
||||
for len(out) < n {
|
||||
var b [16]byte
|
||||
if miss {
|
||||
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
|
||||
} else {
|
||||
b[0] = 0xfd
|
||||
}
|
||||
for x := 8; x < 16; x++ {
|
||||
b[x] = byte(r.Intn(256))
|
||||
}
|
||||
a := netip.AddrFrom16(b)
|
||||
if _, ok := seen[a]; ok {
|
||||
continue
|
||||
}
|
||||
seen[a] = struct{}{}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
|
||||
// generatePacket for the v4 case.
|
||||
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: protocol,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
var transportLayer gopacket.SerializableLayer
|
||||
switch protocol {
|
||||
case layers.IPProtocolTCP:
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
}
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
|
||||
transportLayer = tcp
|
||||
case layers.IPProtocolUDP:
|
||||
udp := &layers.UDP{
|
||||
SrcPort: layers.UDPPort(srcPort),
|
||||
DstPort: layers.UDPPort(dstPort),
|
||||
}
|
||||
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
|
||||
transportLayer = udp
|
||||
}
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
175
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
175
client/firewall/uspfilter/peer_acl_dedup_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
func newTestManager(t *testing.T) *Manager {
|
||||
t.Helper()
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
return m
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
|
||||
// the same peer rule twice does not create two backing rules. The acl
|
||||
// manager keys its own cache, but the firewall backend must be
|
||||
// idempotent on its own so a double-apply cannot leak rules, matching
|
||||
// the route path and the kernel backends.
|
||||
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "first add")
|
||||
require.Len(t, first, 1, "first add should yield one rule")
|
||||
|
||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "second add")
|
||||
require.Len(t, second, 1, "second add should yield one rule")
|
||||
|
||||
assert.Equal(t, first[0].ID(), second[0].ID(), "duplicate add should return the same rule id")
|
||||
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
|
||||
}
|
||||
|
||||
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
|
||||
// backend's no-refcount contract: a content key installed twice is
|
||||
// still one rule, and the first DeleteFilterRule removes it. The
|
||||
// backend does not refcount, so balance is the caller's job (it keys
|
||||
// its tracking by the returned id and deletes once per key). If this
|
||||
// ever silently grew a refcount, the acl manager's delete accounting
|
||||
// would diverge from the kernel.
|
||||
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
|
||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "first add")
|
||||
require.Len(t, first, 1)
|
||||
|
||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err, "second add")
|
||||
require.Len(t, second, 1)
|
||||
require.Equal(t, first[0].ID(), second[0].ID(), "dedup to one rule")
|
||||
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
|
||||
|
||||
require.NoError(t, m.DeleteFilterRule(first[0]), "delete once")
|
||||
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
|
||||
assert.NotContains(t, m.peerRulesMap, first[0].ID(), "dedup map entry cleared")
|
||||
}
|
||||
|
||||
// TestDeletePeerFiltering_ByRuleID verifies a peer rule can be deleted
|
||||
// by its bare RuleID, not only by the concrete *PeerRule, so a caller
|
||||
// that tracks ids cannot accidentally fall through to the route path.
|
||||
func TestDeletePeerFiltering_ByRuleID(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rules, 1)
|
||||
|
||||
require.NoError(t, m.DeleteFilterRule(rules[0].ID()), "delete by bare id")
|
||||
assert.Empty(t, m.incomingRules, "rule removed when deleted by id")
|
||||
assert.NotContains(t, m.peerRulesMap, rules[0].ID())
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
|
||||
// content hash, not a random UUID: identical inputs produce the same id
|
||||
// across independent managers. A random id breaks caller-side dedup.
|
||||
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
|
||||
ip := net.ParseIP("10.0.0.5")
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
|
||||
m1 := newTestManager(t)
|
||||
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r1, 1)
|
||||
|
||||
m2 := newTestManager(t)
|
||||
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r2, 1)
|
||||
|
||||
assert.Equal(t, r1[0].ID(), r2[0].ID(), "same inputs must produce the same rule id")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
|
||||
// differing only by port are kept separate.
|
||||
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
action := fw.ActionAccept
|
||||
|
||||
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r80, 1)
|
||||
|
||||
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, r443, 1)
|
||||
|
||||
assert.NotEqual(t, r80[0].ID(), r443[0].ID(), "different ports must produce different rule ids")
|
||||
assert.Len(t, m.incomingRules, 2, "distinct rules must both be stored")
|
||||
}
|
||||
|
||||
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
|
||||
// matching on source port and one matching on destination port for the
|
||||
// same selector do not collide: the port lands in a different slot, so
|
||||
// the content key must differ.
|
||||
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
ip := net.ParseIP("192.168.1.1")
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
|
||||
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, dPortRule, 1)
|
||||
|
||||
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, sPortRule, 1)
|
||||
|
||||
assert.NotEqual(t, dPortRule[0].ID(), sPortRule[0].ID(), "source-port and dest-port matches must produce different rule ids")
|
||||
}
|
||||
|
||||
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
|
||||
// list is rejected rather than treated as "match any". "Match any" must
|
||||
// be an explicit /0, so a zeroed list can never silently widen a rule to
|
||||
// every source.
|
||||
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
|
||||
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
|
||||
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
|
||||
assert.Empty(t, m.incomingRules, "no rule should be stored for empty sources")
|
||||
}
|
||||
106
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
106
client/firewall/uspfilter/peer_acl_ipv6_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
func newV6TestManager(t *testing.T, localV6 string) *Manager {
|
||||
t.Helper()
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
IPv6: netip.MustParseAddr(localV6),
|
||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
||||
}
|
||||
},
|
||||
}
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err, "create manager")
|
||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
||||
return m
|
||||
}
|
||||
|
||||
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
|
||||
t.Helper()
|
||||
ip6 := &layers.IPv6{
|
||||
Version: 6,
|
||||
HopLimit: 64,
|
||||
NextHeader: layers.IPProtocolUDP,
|
||||
SrcIP: net.ParseIP(src),
|
||||
DstIP: net.ParseIP(dst),
|
||||
}
|
||||
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
|
||||
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
|
||||
// rules: a matching v6 source is accepted, a non-matching one is
|
||||
// denied by the default. This is the end-to-end proof that the index
|
||||
// is not v4-only.
|
||||
func TestPeerACL_IPv6HostRule(t *testing.T) {
|
||||
m := newV6TestManager(t, "fd00::100")
|
||||
|
||||
src := net.ParseIP("fd00::1")
|
||||
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||
require.NoError(t, err, "add v6 accept rule")
|
||||
|
||||
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
|
||||
"v6 packet from the allowed /128 source must be accepted")
|
||||
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
|
||||
"v6 packet from an unlisted source must be denied by default")
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
|
||||
// right index bucket: a /128 in bySource keyed by its address, a
|
||||
// coarser prefix in byCIDR, and ::/0 in matchAny.
|
||||
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
|
||||
m := newV6TestManager(t, "fd00::100")
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
|
||||
host := netip.MustParseAddr("fd00::1")
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
|
||||
|
||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.incomingAcceptIndex.byCIDR, 1, "coarser v6 prefix must land in byCIDR")
|
||||
|
||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, m.incomingAcceptIndex.matchAny, 1, "::/0 source must land in matchAny")
|
||||
}
|
||||
|
||||
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
|
||||
// source prefix is normalized to v4 so a plain v4 packet matches it.
|
||||
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
|
||||
m := newTestManager(t)
|
||||
|
||||
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
|
||||
rules, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rules, 1)
|
||||
|
||||
v4 := netip.MustParseAddr("192.168.1.1")
|
||||
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
|
||||
}
|
||||
@@ -10,24 +10,50 @@ import (
|
||||
|
||||
// PeerRule to handle management of rules
|
||||
type PeerRule struct {
|
||||
id string
|
||||
mgmtId []byte
|
||||
ip netip.Addr
|
||||
ipLayer gopacket.LayerType
|
||||
matchByIP bool
|
||||
id firewall.RuleID
|
||||
mgmtId []byte
|
||||
// sources is the canonical list of source prefixes this rule
|
||||
// matches against. A single 0.0.0.0/0 (or ::/0) entry means
|
||||
// "match any source".
|
||||
sources []netip.Prefix
|
||||
// sourceAddrs is a fast-path membership set for host-prefix
|
||||
// sources (/32 v4, /128 v6). Populated alongside sources;
|
||||
// consulted before falling back to prefix scan.
|
||||
sourceAddrs map[netip.Addr]struct{}
|
||||
// matchAny is true when sources covers everything (0.0.0.0/0,
|
||||
// ::/0). In that case neither sourceAddrs nor sources need to be
|
||||
// consulted.
|
||||
matchAny bool
|
||||
protoLayer gopacket.LayerType
|
||||
sPort *firewall.Port
|
||||
dPort *firewall.Port
|
||||
drop bool
|
||||
srcPort *firewall.Port
|
||||
dstPort *firewall.Port
|
||||
action firewall.Action
|
||||
}
|
||||
|
||||
// matchesSource reports whether the given source address is covered
|
||||
// by this rule's source list.
|
||||
func (r *PeerRule) matchesSource(src netip.Addr) bool {
|
||||
if r.matchAny {
|
||||
return true
|
||||
}
|
||||
if _, ok := r.sourceAddrs[src]; ok {
|
||||
return true
|
||||
}
|
||||
for _, p := range r.sources {
|
||||
if p.Contains(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *PeerRule) ID() string {
|
||||
func (r *PeerRule) ID() firewall.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
type RouteRule struct {
|
||||
id string
|
||||
id firewall.RuleID
|
||||
mgmtId []byte
|
||||
sources []netip.Prefix
|
||||
dstSet firewall.Set
|
||||
@@ -39,6 +65,6 @@ type RouteRule struct {
|
||||
}
|
||||
|
||||
// ID returns the rule id
|
||||
func (r *RouteRule) ID() string {
|
||||
func (r *RouteRule) ID() firewall.RuleID {
|
||||
return r.id
|
||||
}
|
||||
|
||||
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
50
client/firewall/uspfilter/testhelpers_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
// countRulesForAddr reports how many rules in the given slice match
|
||||
// the supplied source address.
|
||||
func countRulesForAddr(rules PeerRules, src netip.Addr) int {
|
||||
n := 0
|
||||
for _, r := range rules {
|
||||
if r.matchesSource(src) {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// findRuleByID returns true if the rules slice contains a rule with
|
||||
// the given id whose source set covers src.
|
||||
func findRuleByID(rules PeerRules, src netip.Addr, id firewall.RuleID) bool {
|
||||
for _, r := range rules {
|
||||
if r.id == id && r.matchesSource(src) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// pfx converts a single net.IP into the []netip.Prefix form
|
||||
// AddFilterRule expects. A nil or unspecified address becomes a /0
|
||||
// ("match any") prefix in the matching family; any other address
|
||||
// becomes its /32 (or /128) host prefix.
|
||||
func pfx(ip net.IP) []netip.Prefix {
|
||||
if ip == nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if ip.IsUnspecified() {
|
||||
if ip.To4() != nil {
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
a, _ := netip.AddrFromSlice(ip)
|
||||
a = a.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
||||
}
|
||||
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
|
||||
|
||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
|
||||
ip := net.ParseIP("1.1.1.1")
|
||||
proto := fw.ProtocolICMP
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolUDP
|
||||
port := &fw.Port{Values: []uint16{53}}
|
||||
action := fw.ActionAccept
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
|
||||
proto := fw.ProtocolTCP
|
||||
port := &fw.Port{Values: []uint16{80}}
|
||||
action := fw.ActionDrop
|
||||
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
packetBuilder: func() *PacketBuilder {
|
||||
|
||||
190
client/internal/acl/dispatch_test.go
Normal file
190
client/internal/acl/dispatch_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
|
||||
// invariant: the backends classify a rule as a peer rule purely by the
|
||||
// absence of a destination (neither prefix nor set). A default route
|
||||
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
|
||||
// a route, not collapse into the peer path.
|
||||
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
|
||||
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
|
||||
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
|
||||
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
|
||||
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
|
||||
}
|
||||
|
||||
// A zero-value Network is the only peer-rule shape.
|
||||
var empty fwmgr.Network
|
||||
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
|
||||
assert.False(t, empty.IsSet(), "zero Network must not be a set")
|
||||
}
|
||||
|
||||
// TestDetermineDestinationAlwaysRoute verifies determineDestination
|
||||
// never yields an empty Network for a valid route rule: every branch
|
||||
// (static prefix, default route, dynamic with/without domains, with and
|
||||
// without a local resolver) produces a destination that classifies as a
|
||||
// route. If this regresses, a route rule would be dispatched down the
|
||||
// peer path, which matches on source only.
|
||||
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
|
||||
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
||||
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
rule *mgmProto.RouteFirewallRule
|
||||
resolver bool
|
||||
sources []netip.Prefix
|
||||
}{
|
||||
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
|
||||
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
|
||||
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
|
||||
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
|
||||
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
|
||||
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
|
||||
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, dest.IsPrefix() || dest.IsSet(),
|
||||
"destination must classify as a route, got empty Network")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// countingFirewall wraps a real firewall.Manager and counts filter-rule
|
||||
// add/delete calls so a test can assert how many backing rules the acl
|
||||
// manager actually creates and tears down.
|
||||
type countingFirewall struct {
|
||||
fwmgr.Manager
|
||||
mu sync.Mutex
|
||||
addCalls int
|
||||
dels int
|
||||
ruleIDs map[fwmgr.RuleID]struct{}
|
||||
}
|
||||
|
||||
// distinctRules returns the number of distinct backing rules the
|
||||
// backend produced. Because the backend dedups identical content,
|
||||
// repeated AddFilterRule calls for the same rule resolve to one id.
|
||||
func (f *countingFirewall) distinctRules() int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return len(f.ruleIDs)
|
||||
}
|
||||
|
||||
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) ([]fwmgr.Rule, error) {
|
||||
rules, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
||||
if err == nil {
|
||||
f.mu.Lock()
|
||||
f.addCalls++
|
||||
if f.ruleIDs == nil {
|
||||
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
|
||||
}
|
||||
for _, r := range rules {
|
||||
f.ruleIDs[r.ID()] = struct{}{}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return rules, err
|
||||
}
|
||||
|
||||
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||
err := f.Manager.DeleteFilterRule(r)
|
||||
if err == nil {
|
||||
f.mu.Lock()
|
||||
f.dels++
|
||||
delete(f.ruleIDs, r.ID())
|
||||
f.mu.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
|
||||
t.Helper()
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
fw := &countingFirewall{Manager: realFW}
|
||||
cleanup := func() {
|
||||
require.NoError(t, realFW.Close(nil))
|
||||
ctrl.Finish()
|
||||
}
|
||||
return NewDefaultManager(fw), fw, cleanup
|
||||
}
|
||||
|
||||
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
|
||||
// the backends rely on: two policies that authorize an identical flow
|
||||
// (same selector and sources) collapse to a single backing firewall
|
||||
// rule, and that rule survives until BOTH policies are gone. This is
|
||||
// why the backend can dedup on add without refcounting on delete: the
|
||||
// acl manager's pair key matches the backend's content key, so add and
|
||||
// delete stay balanced per content key across full-state reapplies.
|
||||
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
|
||||
acl, fw, cleanup := newCountingACL(t)
|
||||
defer cleanup()
|
||||
|
||||
ruleA := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
ruleB := &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-B"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
|
||||
// Both policies present: identical content collapses to one rule.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
|
||||
|
||||
// Drop policy A only: the shared rule is still authorized by B, so
|
||||
// nothing is deleted.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
|
||||
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
|
||||
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
|
||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||
|
||||
// Drop policy B too: now the content key has no authorizer and the
|
||||
// single backing rule is removed exactly once.
|
||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
|
||||
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
|
||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||
}
|
||||
238
client/internal/acl/grouping_test.go
Normal file
238
client/internal/acl/grouping_test.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
|
||||
// with identical selectors but different PolicyIDs do NOT get merged
|
||||
// into one group, so each policy's sources merge under its own
|
||||
// attribution id. (Identical-content groups may still dedup to one
|
||||
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
|
||||
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-B"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, merr, _ := groupPeerRules(rules)
|
||||
require.Nil(t, merr.ErrorOrNil())
|
||||
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
|
||||
// belonging to the same policy don't merge.
|
||||
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "2001:db8::1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, merr, _ := groupPeerRules(rules)
|
||||
require.Nil(t, merr.ErrorOrNil())
|
||||
require.Len(t, groups, 2, "rules of different families must produce separate groups")
|
||||
|
||||
var sawV4, sawV6 bool
|
||||
for _, g := range groups {
|
||||
require.Len(t, g.sources, 1)
|
||||
if g.sources[0].Addr().Is4() {
|
||||
sawV4 = true
|
||||
}
|
||||
if g.sources[0].Addr().Is6() {
|
||||
sawV6 = true
|
||||
}
|
||||
}
|
||||
assert.True(t, sawV4 && sawV6)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
|
||||
// every distinguishing field (policy, family, direction, action,
|
||||
// proto, port) collapse into a single multi-source group.
|
||||
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
|
||||
mk := func(peerIP string) *mgmProto.FirewallRule {
|
||||
return &mgmProto.FirewallRule{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: peerIP,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
}
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
|
||||
|
||||
groups, merr, _ := groupPeerRules(rules)
|
||||
require.Nil(t, merr.ErrorOrNil())
|
||||
require.Len(t, groups, 1)
|
||||
require.Len(t, groups[0].sources, 3)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
|
||||
// new sourcePrefixes wire field is consumed and produces a
|
||||
// multi-source group in one shot (no client-side merging needed).
|
||||
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
|
||||
srcs := [][]byte{
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
|
||||
}
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
SourcePrefixes: srcs,
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, merr, _ := groupPeerRules(rules)
|
||||
require.Nil(t, merr.ErrorOrNil())
|
||||
require.Len(t, groups, 1)
|
||||
require.Len(t, groups[0].sources, 3)
|
||||
}
|
||||
|
||||
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
|
||||
// and drop rules with the same selector don't merge.
|
||||
func TestGroupPeerRulesActionSeparates(t *testing.T) {
|
||||
rules := []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
}
|
||||
|
||||
groups, merr, _ := groupPeerRules(rules)
|
||||
require.Nil(t, merr.ErrorOrNil())
|
||||
require.Len(t, groups, 2)
|
||||
}
|
||||
|
||||
// failingDeleteFirewall wraps a real firewall.Manager and forces the
|
||||
// next N DeleteFilterRule calls to fail. Used to verify that the acl
|
||||
// manager retains rules whose deletion was rejected by the backend,
|
||||
// so they get retried on the next ApplyFiltering pass instead of
|
||||
// becoming orphans.
|
||||
type failingDeleteFirewall struct {
|
||||
fwmgr.Manager
|
||||
failCount int
|
||||
}
|
||||
|
||||
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
||||
if f.failCount > 0 {
|
||||
f.failCount--
|
||||
return errors.New("simulated delete failure")
|
||||
}
|
||||
return f.Manager.DeleteFilterRule(r)
|
||||
}
|
||||
|
||||
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
|
||||
// transient DeleteFilterRule error doesn't make the acl manager
|
||||
// forget about a rule. The rule must remain in peerRulesPairs so the
|
||||
// next ApplyFiltering pass attempts the delete again.
|
||||
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
|
||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
||||
|
||||
fw := &failingDeleteFirewall{Manager: realFW}
|
||||
acl := NewDefaultManager(fw)
|
||||
|
||||
// First pass: install a rule.
|
||||
netmap1 := &mgmProto.NetworkMap{
|
||||
FirewallRules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PolicyID: []byte("policy-A"),
|
||||
PeerIP: "10.0.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "22",
|
||||
},
|
||||
},
|
||||
FirewallRulesIsEmpty: false,
|
||||
}
|
||||
acl.ApplyFiltering(netmap1, false)
|
||||
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
|
||||
|
||||
// Second pass: remove the rule from the map. The backend will
|
||||
// fail the delete; the acl manager must retain the rule.
|
||||
fw.failCount = 1
|
||||
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
|
||||
acl.ApplyFiltering(netmap2, false)
|
||||
require.Equal(t, 1, len(acl.peerRulesPairs),
|
||||
"rule must be retained when DeleteFilterRule fails so it gets retried")
|
||||
|
||||
// Third pass: same map, backend no longer fails. The rule
|
||||
// should now succeed in being removed.
|
||||
acl.ApplyFiltering(netmap2, false)
|
||||
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
|
||||
}
|
||||
@@ -5,18 +5,21 @@ import (
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
)
|
||||
|
||||
type RuleID string
|
||||
// RuleID aliases manager.RuleID so existing nbid.RuleID references
|
||||
// keep working while the canonical type lives in the firewall package.
|
||||
type RuleID = manager.RuleID
|
||||
|
||||
func (r RuleID) ID() string {
|
||||
return string(r)
|
||||
}
|
||||
|
||||
func GenerateRouteRuleKey(
|
||||
// GenerateRuleID returns a deterministic content hash identifying a
|
||||
// filter rule. It covers both peer rules (empty destination) and route
|
||||
// rules (prefix or set destination), so identical rules dedup to the
|
||||
// same id across backends regardless of which path created them.
|
||||
func GenerateRuleID(
|
||||
sources []netip.Prefix,
|
||||
destination manager.Network,
|
||||
proto manager.Protocol,
|
||||
@@ -24,6 +27,7 @@ func GenerateRouteRuleKey(
|
||||
dPort *manager.Port,
|
||||
action manager.Action,
|
||||
) RuleID {
|
||||
sources = slices.Clone(sources)
|
||||
manager.SortPrefixes(sources)
|
||||
|
||||
h := sha256.New()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
@@ -31,7 +29,6 @@ type Manager interface {
|
||||
// DefaultManager uses firewall manager to handle
|
||||
type DefaultManager struct {
|
||||
firewall firewall.Manager
|
||||
ipsetCounter int
|
||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||
routeRules map[id.RuleID]struct{}
|
||||
mutex sync.Mutex
|
||||
@@ -102,59 +99,271 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||
)
|
||||
}
|
||||
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
ipsetByRuleSelectors := make(map[string]string)
|
||||
// Group incoming single-source rules from management by their
|
||||
// (direction, action, proto, port) selector and merge sources.
|
||||
// One call to the firewall backend per merged rule.
|
||||
groups, merr, denyFailed := groupPeerRules(rules)
|
||||
if denyFailed {
|
||||
log.Errorf("a deny rule failed to decode its sources, skipping this pass to avoid fail-open: %v", nberrors.FormatErrorOrNil(merr))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||
// the missing deny. Currently we accumulate errors and continue.
|
||||
var merr *multierror.Error
|
||||
for _, r := range rules {
|
||||
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||
selector := d.getRuleGroupingSelector(r)
|
||||
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||
if !ok {
|
||||
d.ipsetCounter++
|
||||
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||
ipsetByRuleSelectors[selector] = ipsetName
|
||||
}
|
||||
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
}
|
||||
if len(rulePair) > 0 {
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||
|
||||
// Apply denies first. A deny that fails to install is a security
|
||||
// failure (fail-open), so if any deny errors we roll back the
|
||||
// denies we already installed in this pass and bail out without
|
||||
// installing any accept. Pre-existing rules stay untouched until
|
||||
// the next successful pass clears them.
|
||||
denies, accepts := splitDenyAccept(groups)
|
||||
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
|
||||
log.Errorf("deny install failed, skipping accepts to avoid fail-open: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
}
|
||||
|
||||
if merr != nil {
|
||||
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
|
||||
// Tear down rules that disappeared from the networkmap. Any rule
|
||||
// the backend refuses to delete stays in our tracking so it gets
|
||||
// retried on the next ApplyFiltering. Otherwise a transient
|
||||
// delete failure would leak the rule in the kernel until the
|
||||
// process exits.
|
||||
for pairID, rules := range d.peerRulesPairs {
|
||||
if _, ok := newRulePairs[pairID]; !ok {
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||
continue
|
||||
}
|
||||
if _, ok := newRulePairs[pairID]; ok {
|
||||
continue
|
||||
}
|
||||
var remaining []firewall.Rule
|
||||
for _, rule := range rules {
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
|
||||
remaining = append(remaining, rule)
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
}
|
||||
if len(remaining) > 0 {
|
||||
newRulePairs[pairID] = remaining
|
||||
}
|
||||
}
|
||||
d.peerRulesPairs = newRulePairs
|
||||
}
|
||||
|
||||
// splitDenyAccept partitions groups by action so denies can be
|
||||
// applied before accepts. Order within each bucket is preserved.
|
||||
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
|
||||
for _, g := range groups {
|
||||
if g.action == mgmProto.RuleAction_DROP {
|
||||
denies = append(denies, g)
|
||||
} else {
|
||||
accepts = append(accepts, g)
|
||||
}
|
||||
}
|
||||
return denies, accepts
|
||||
}
|
||||
|
||||
// installPeerGroups applies each group and records the resulting rule
|
||||
// pairs in newRulePairs. With atomic set (deny rules), a single failure
|
||||
// rolls back every rule installed in this call and returns, leaving the
|
||||
// kernel exactly as before: denies are fail-closed and must be applied
|
||||
// all-or-nothing. With atomic unset (accept rules), failures are
|
||||
// accumulated and the remaining groups still install, so one malformed
|
||||
// allow cannot drop every other legitimate allow in the pass.
|
||||
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
|
||||
var freshlyInstalled []id.RuleID
|
||||
var merr *multierror.Error
|
||||
for _, g := range groups {
|
||||
pairID, rulePair, err := d.applyPeerGroup(g)
|
||||
if err != nil {
|
||||
if atomic {
|
||||
d.rollbackInstalled(freshlyInstalled)
|
||||
return fmt.Errorf("apply firewall rule: %w", err)
|
||||
}
|
||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||
continue
|
||||
}
|
||||
if len(rulePair) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, existed := d.peerRulesPairs[pairID]; !existed {
|
||||
freshlyInstalled = append(freshlyInstalled, pairID)
|
||||
}
|
||||
d.peerRulesPairs[pairID] = rulePair
|
||||
newRulePairs[pairID] = rulePair
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
|
||||
for _, pairID := range pairIDs {
|
||||
for _, rule := range d.peerRulesPairs[pairID] {
|
||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
||||
log.Errorf("rollback peer rule %s: %v", pairID, err)
|
||||
}
|
||||
}
|
||||
delete(d.peerRulesPairs, pairID)
|
||||
}
|
||||
}
|
||||
|
||||
// peerRuleGroup collapses a set of single-source FirewallRules sharing
|
||||
// the same selector into one multi-source rule to push to the backend.
|
||||
type peerRuleGroup struct {
|
||||
direction mgmProto.RuleDirection
|
||||
action mgmProto.RuleAction
|
||||
protocol mgmProto.RuleProtocol
|
||||
port *mgmProto.PortInfo
|
||||
// legacyPort is used only when PortInfo is empty (old management).
|
||||
legacyPort string
|
||||
policyID []byte
|
||||
sources []netip.Prefix
|
||||
}
|
||||
|
||||
// groupPeerRules merges single-source rules sharing a selector into
|
||||
// multi-source groups. The bool return reports whether any deny rule
|
||||
// failed to decode its sources: a deny we cannot realize is a
|
||||
// fail-open risk, so the caller skips the whole pass and retries rather
|
||||
// than installing accepts on top of a missing deny.
|
||||
func groupPeerRules(rules []*mgmProto.FirewallRule) ([]*peerRuleGroup, *multierror.Error, bool) {
|
||||
var merr *multierror.Error
|
||||
denyFailed := false
|
||||
byKey := make(map[string]*peerRuleGroup)
|
||||
order := make([]string, 0)
|
||||
|
||||
for _, r := range rules {
|
||||
srcs, err := extractRuleSources(r)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
if r.Action == mgmProto.RuleAction_DROP {
|
||||
denyFailed = true
|
||||
}
|
||||
continue
|
||||
}
|
||||
// extractRuleSources returns at least one source on success;
|
||||
// pick the family from the first to key this group. Sources
|
||||
// from a single FirewallRule are always same-family (mgmt
|
||||
// emits one rule per family), so this is unambiguous.
|
||||
family := familyTag(srcs[0])
|
||||
key := ruleGroupKey(r, family)
|
||||
g, ok := byKey[key]
|
||||
if !ok {
|
||||
g = &peerRuleGroup{
|
||||
direction: r.Direction,
|
||||
action: r.Action,
|
||||
protocol: r.Protocol,
|
||||
port: r.PortInfo,
|
||||
legacyPort: r.Port,
|
||||
policyID: r.PolicyID,
|
||||
}
|
||||
byKey[key] = g
|
||||
order = append(order, key)
|
||||
}
|
||||
g.sources = append(g.sources, srcs...)
|
||||
}
|
||||
|
||||
out := make([]*peerRuleGroup, 0, len(order))
|
||||
for _, k := range order {
|
||||
out = append(out, byKey[k])
|
||||
}
|
||||
return out, merr, denyFailed
|
||||
}
|
||||
|
||||
func familyTag(p netip.Prefix) string {
|
||||
if p.Addr().Is6() && !p.Addr().Is4In6() {
|
||||
return "v6"
|
||||
}
|
||||
return "v4"
|
||||
}
|
||||
|
||||
// ruleGroupKey returns a string that uniquely identifies a peer-rule
|
||||
// selector. Rules sharing a key can be collapsed into one multi-source
|
||||
// rule pushed to the firewall backend.
|
||||
//
|
||||
// All distinguishing fields must be in the key:
|
||||
// - family (v4 vs v6): mgmt emits one FirewallRule per family per
|
||||
// peer, and merging would produce mixed-family groups that broke
|
||||
// ICMP-variant selection in uspfilter.
|
||||
// - policyID: two policies may authorize different source peers for
|
||||
// the same proto/port/direction. Keeping them in separate groups
|
||||
// keeps each policy's sources in its own backend rule instead of
|
||||
// merging unrelated peers into one rule attributed to a single
|
||||
// policy. (When two policies authorize the identical sources and
|
||||
// selector, the backend dedups them to one rule regardless,
|
||||
// attributed to whichever was applied first.)
|
||||
// - direction, action, protocol, port: behavioral fields; mismatched
|
||||
// rules must produce mismatched kernel rules.
|
||||
func ruleGroupKey(r *mgmProto.FirewallRule, family string) string {
|
||||
return fmt.Sprintf("%s:%x:%d:%d:%d:%s:%v",
|
||||
family, r.PolicyID, r.Direction, r.Action, r.Protocol, r.Port, r.PortInfo)
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
|
||||
protocol, err := convertToFirewallProtocol(g.protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||
}
|
||||
action, err := convertFirewallAction(g.action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
||||
}
|
||||
port, err := resolveGroupPort(g)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
var fwRules []firewall.Rule
|
||||
switch g.direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
fwRules, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return "", nil, nil
|
||||
}
|
||||
fwRules, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
if len(fwRules) == 0 {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
// Derive the pair id from the backend rule, like the route path:
|
||||
// the backend dedups identical content, so two policies authorizing
|
||||
// the same flow resolve to the same id and a single backing rule.
|
||||
return fwRules[0].ID(), fwRules, nil
|
||||
}
|
||||
|
||||
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
|
||||
if !portInfoEmpty(g.port) {
|
||||
return convertPortInfo(g.port), nil
|
||||
}
|
||||
if g.legacyPort != "" {
|
||||
value, err := strconv.Atoi(g.legacyPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
return &firewall.Port{Values: []uint16{uint16(value)}}, nil
|
||||
}
|
||||
// nolint:nilnil // a nil port legitimately means "no port restriction"
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||
var merr *multierror.Error
|
||||
|
||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||
for _, rule := range rules {
|
||||
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
ruleID, err := d.applyRouteACL(rule, dynamicResolver)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||
@@ -163,16 +372,18 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
||||
}
|
||||
continue
|
||||
}
|
||||
newRouteRules[id] = struct{}{}
|
||||
newRouteRules[ruleID] = struct{}{}
|
||||
}
|
||||
|
||||
// Clean up old firewall rules
|
||||
for id := range d.routeRules {
|
||||
if _, exists := newRouteRules[id]; !exists {
|
||||
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||
}
|
||||
// implicitly deleted from the map
|
||||
// Tear down old route rules; retain ones the backend refused so a
|
||||
// transient failure doesn't leave kernel-side orphans.
|
||||
for ruleID := range d.routeRules {
|
||||
if _, exists := newRouteRules[ruleID]; exists {
|
||||
continue
|
||||
}
|
||||
if err := d.firewall.DeleteFilterRule(ruleID); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
|
||||
newRouteRules[ruleID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -191,7 +402,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse source range: %w", err)
|
||||
}
|
||||
sources = append(sources, source)
|
||||
sources = append(sources, firewall.UnmapPrefix(source))
|
||||
}
|
||||
|
||||
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||
@@ -211,71 +422,15 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamic
|
||||
|
||||
dPorts := convertPortInfo(rule.PortInfo)
|
||||
|
||||
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
addedRules, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("add route rule: %w", err)
|
||||
}
|
||||
|
||||
return id.RuleID(addedRule.ID()), nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||
r *mgmProto.FirewallRule,
|
||||
ipsetName string,
|
||||
) (id.RuleID, []firewall.Rule, error) {
|
||||
ip, err := extractRuleIP(r)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
if len(addedRules) == 0 {
|
||||
return "", fmt.Errorf("add route rule: backend returned no rules")
|
||||
}
|
||||
|
||||
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
action, err := convertFirewallAction(r.Action)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||
}
|
||||
|
||||
var port *firewall.Port
|
||||
if !portInfoEmpty(r.PortInfo) {
|
||||
port = convertPortInfo(r.PortInfo)
|
||||
} else if r.Port != "" {
|
||||
// old version of management, single port
|
||||
value, err := strconv.Atoi(r.Port)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
port = &firewall.Port{
|
||||
Values: []uint16{uint16(value)},
|
||||
}
|
||||
}
|
||||
|
||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||
return ruleID, rulesPair, nil
|
||||
}
|
||||
|
||||
var rules []firewall.Rule
|
||||
switch r.Direction {
|
||||
case mgmProto.RuleDirection_IN:
|
||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
case mgmProto.RuleDirection_OUT:
|
||||
if d.firewall.IsStateful() {
|
||||
return "", nil, nil
|
||||
}
|
||||
// return traffic for outbound connections if firewall is stateless
|
||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||
default:
|
||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return ruleID, rules, nil
|
||||
return addedRules[0].ID(), nil
|
||||
}
|
||||
|
||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
@@ -294,82 +449,28 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addInRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (d *DefaultManager) addOutRules(
|
||||
id []byte,
|
||||
ip netip.Addr,
|
||||
protocol firewall.Protocol,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
ipsetName string,
|
||||
) ([]firewall.Rule, error) {
|
||||
if shouldSkipInvertedRule(protocol, port) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||
}
|
||||
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||
func (d *DefaultManager) getPeerRuleID(
|
||||
ip netip.Addr,
|
||||
proto firewall.Protocol,
|
||||
direction int,
|
||||
port *firewall.Port,
|
||||
action firewall.Action,
|
||||
) id.RuleID {
|
||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||
if port != nil {
|
||||
idStr += port.String()
|
||||
}
|
||||
|
||||
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||
}
|
||||
|
||||
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||
}
|
||||
|
||||
|
||||
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||
// extractRuleSources returns all source prefixes the rule applies to.
|
||||
// New management populates sourcePrefixes; older management sets PeerIP.
|
||||
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
|
||||
if len(r.SourcePrefixes) > 0 {
|
||||
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
|
||||
for _, raw := range r.SourcePrefixes {
|
||||
addr, err := netiputil.DecodeAddr(raw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode source prefix: %w", err)
|
||||
}
|
||||
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
return out, nil
|
||||
}
|
||||
|
||||
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||
addr, err := netip.ParseAddr(r.PeerIP)
|
||||
if err != nil {
|
||||
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
return nil, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||
}
|
||||
return addr.Unmap(), nil
|
||||
addr = addr.Unmap()
|
||||
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
|
||||
}
|
||||
|
||||
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
@@ -76,9 +77,9 @@ func TestDefaultManager(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("add extra rules", func(t *testing.T) {
|
||||
existedPairs := map[string]struct{}{}
|
||||
existedPairs := map[fwmanager.RuleID]struct{}{}
|
||||
for id := range acl.peerRulesPairs {
|
||||
existedPairs[id.ID()] = struct{}{}
|
||||
existedPairs[id] = struct{}{}
|
||||
}
|
||||
|
||||
// remove first rule
|
||||
@@ -105,7 +106,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
// check that old rule was removed
|
||||
previousCount := 0
|
||||
for id := range acl.peerRulesPairs {
|
||||
if _, ok := existedPairs[id.ID()]; ok {
|
||||
if _, ok := existedPairs[id]; ok {
|
||||
previousCount++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -360,7 +360,13 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
|
||||
return true
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf(":%s", port)
|
||||
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
|
||||
// alias by default, ensure explicit ip for localhost.
|
||||
host := parsedURL.Hostname()
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
addr := net.JoinHostPort(host, port)
|
||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||
if err != nil {
|
||||
return false
|
||||
|
||||
@@ -116,7 +116,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
fileDescriptor int32,
|
||||
networkChangeListener listener.NetworkChangeListener,
|
||||
dnsManager dns.IosDnsManager,
|
||||
dnsAddresses []netip.AddrPort,
|
||||
stateFilePath string,
|
||||
) error {
|
||||
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
|
||||
@@ -126,7 +125,6 @@ func (c *ConnectClient) RunOniOS(
|
||||
FileDescriptor: fileDescriptor,
|
||||
NetworkChangeListener: networkChangeListener,
|
||||
DnsManager: dnsManager,
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
StateFilePath: stateFilePath,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
|
||||
@@ -339,8 +339,7 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||
case entry.Pattern == ".":
|
||||
return true
|
||||
case entry.IsWildcard:
|
||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
||||
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
||||
return strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
default:
|
||||
// For non-wildcard patterns:
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
|
||||
@@ -164,6 +164,54 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||
matchSubdomains: true,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary mismatch (suffix overlap)",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.ab.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard label-boundary match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard multi-label match",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "x.y.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on multi-label apex",
|
||||
handlerDomain: "*.b.test.",
|
||||
queryDomain: "b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard no match on unrelated suffix containment",
|
||||
handlerDomain: "*.example.com.",
|
||||
queryDomain: "notexample.com.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard accepts pattern registered without trailing dot",
|
||||
handlerDomain: "*.b.test",
|
||||
queryDomain: "x.b.test.",
|
||||
isWildcard: true,
|
||||
matchSubdomains: false,
|
||||
shouldMatch: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@@ -273,6 +321,19 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 2, // highest priority matching handler should be called
|
||||
},
|
||||
{
|
||||
name: "overlapping wildcard suffixes route to correct handler",
|
||||
handlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
}{
|
||||
{pattern: "*.b.test.", priority: nbdns.PriorityDNSRoute},
|
||||
{pattern: "*.ab.test.", priority: nbdns.PriorityDNSRoute},
|
||||
},
|
||||
queryDomain: "app.ab.test.",
|
||||
expectedCalls: 1,
|
||||
expectedHandler: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone with specific domain",
|
||||
handlers: []struct {
|
||||
|
||||
@@ -16,6 +16,10 @@ type hostManager interface {
|
||||
restoreHostDNS() error
|
||||
supportCustomPort() bool
|
||||
string() string
|
||||
// getOriginalNameservers returns the OS-side resolvers used as PriorityFallback
|
||||
// upstreams: pre-takeover snapshots on desktop, the OS-pushed list on Android,
|
||||
// hardcoded Quad9 on iOS, nil for noop / mock.
|
||||
getOriginalNameservers() []netip.Addr
|
||||
}
|
||||
|
||||
type SystemDNSSettings struct {
|
||||
@@ -131,3 +135,11 @@ func (n noopHostConfigurator) supportCustomPort() bool {
|
||||
func (n noopHostConfigurator) string() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
func (n noopHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHostConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// androidHostManager is a noop on the OS side (Android's VPN service handles
|
||||
// DNS for us) but tracks the OS-reported resolver list pushed via
|
||||
// OnUpdatedHostDNSServer so it can serve as the fallback nameserver source.
|
||||
type androidHostManager struct {
|
||||
holder *hostsDNSHolder
|
||||
}
|
||||
|
||||
func newHostManager() (*androidHostManager, error) {
|
||||
return &androidHostManager{}, nil
|
||||
func newHostManager(holder *hostsDNSHolder) (*androidHostManager, error) {
|
||||
return &androidHostManager{holder: holder}, nil
|
||||
}
|
||||
|
||||
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
|
||||
@@ -26,3 +32,12 @@ func (a androidHostManager) supportCustomPort() bool {
|
||||
func (a androidHostManager) string() string {
|
||||
return "none"
|
||||
}
|
||||
|
||||
func (a androidHostManager) getOriginalNameservers() []netip.Addr {
|
||||
hosts := a.holder.get()
|
||||
out := make([]netip.Addr, 0, len(hosts))
|
||||
for ap := range hosts {
|
||||
out = append(out, ap.Addr())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -20,6 +21,14 @@ func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a iosHostManager) getOriginalNameservers() []netip.Addr {
|
||||
// Quad9 v4+v6: 9.9.9.9, 2620:fe::fe.
|
||||
return []netip.Addr{
|
||||
netip.AddrFrom4([4]byte{9, 9, 9, 9}),
|
||||
netip.AddrFrom16([16]byte{0x26, 0x20, 0x00, 0xfe, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xfe}),
|
||||
}
|
||||
}
|
||||
|
||||
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||
jsonData, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -44,9 +45,11 @@ const (
|
||||
|
||||
nrptMaxDomainsPerRule = 50
|
||||
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||
interfaceConfigPathV6 = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces`
|
||||
interfaceConfigNameServerKey = "NameServer"
|
||||
interfaceConfigDhcpNameSrvKey = "DhcpNameServer"
|
||||
interfaceConfigSearchListKey = "SearchList"
|
||||
|
||||
// Network interface DNS registration settings
|
||||
disableDynamicUpdateKey = "DisableDynamicUpdate"
|
||||
@@ -67,10 +70,11 @@ const (
|
||||
)
|
||||
|
||||
type registryConfigurator struct {
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
guid string
|
||||
routingAll bool
|
||||
gpo bool
|
||||
nrptEntryCount int
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
@@ -94,6 +98,17 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
gpo: useGPO,
|
||||
}
|
||||
|
||||
origNameservers, err := configurator.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from non-WG adapters: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from non-WG adapters; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from non-WG adapters: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
configurator.origNameservers = origNameservers
|
||||
|
||||
if err := configurator.configureInterface(); err != nil {
|
||||
log.Errorf("failed to configure interface settings: %v", err)
|
||||
}
|
||||
@@ -101,6 +116,98 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
|
||||
return configurator, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads DNS addresses from every Tcpip(6) interface
|
||||
// registry key except the WG adapter. v4 and v6 servers live in separate
|
||||
// hives (Tcpip vs Tcpip6) keyed by the same interface GUID.
|
||||
func (r *registryConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
var merr *multierror.Error
|
||||
for _, root := range []string{interfaceConfigPath, interfaceConfigPathV6} {
|
||||
addrs, err := r.captureFromTcpipRoot(root)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("%s: %w", root, err))
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) captureFromTcpipRoot(rootPath string) ([]netip.Addr, error) {
|
||||
root, err := registry.OpenKey(registry.LOCAL_MACHINE, rootPath, registry.READ)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open key: %w", err)
|
||||
}
|
||||
defer closer(root)
|
||||
|
||||
guids, err := root.ReadSubKeyNames(-1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read subkeys: %w", err)
|
||||
}
|
||||
|
||||
var out []netip.Addr
|
||||
for _, guid := range guids {
|
||||
if strings.EqualFold(guid, r.guid) {
|
||||
continue
|
||||
}
|
||||
out = append(out, readInterfaceNameservers(rootPath, guid)...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readInterfaceNameservers(rootPath, guid string) []netip.Addr {
|
||||
keyPath := rootPath + "\\" + guid
|
||||
k, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closer(k)
|
||||
|
||||
// Static NameServer wins over DhcpNameServer for actual resolution.
|
||||
for _, name := range []string{interfaceConfigNameServerKey, interfaceConfigDhcpNameSrvKey} {
|
||||
raw, _, err := k.GetStringValue(name)
|
||||
if err != nil || raw == "" {
|
||||
continue
|
||||
}
|
||||
if out := parseRegistryNameservers(raw); len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRegistryNameservers(raw string) []netip.Addr {
|
||||
var out []netip.Addr
|
||||
for _, field := range strings.FieldsFunc(raw, func(r rune) bool { return r == ',' || r == ' ' || r == '\t' }) {
|
||||
addr, err := netip.ParseAddr(strings.TrimSpace(field))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
// Drop unzoned link-local: not routable without a scope id. If
|
||||
// the user wrote "fe80::1%eth0" ParseAddr preserves the zone.
|
||||
if addr.IsLinkLocalUnicast() && addr.Zone() == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(r.origNameservers)
|
||||
}
|
||||
|
||||
func (r *registryConfigurator) supportCustomPort() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ func (h *hostsDNSHolder) set(list []netip.AddrPort) {
|
||||
h.mutex.Unlock()
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func (h *hostsDNSHolder) get() map[netip.AddrPort]struct{} {
|
||||
h.mutex.RLock()
|
||||
l := h.unprotectedDNSList
|
||||
|
||||
@@ -26,6 +26,19 @@ type resolver interface {
|
||||
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
|
||||
}
|
||||
|
||||
// PeerConnectivity reports whether a tunnel IP belongs to a peer the
|
||||
// client knows about and whether that peer is currently connected. The
|
||||
// local resolver uses this to suppress A/AAAA answers whose RDATA points
|
||||
// at a disconnected peer (typical case: a synthesized private-service
|
||||
// record pointing at an embedded proxy peer that just went offline).
|
||||
//
|
||||
// known=false means the IP isn't in the local peerstore at all — the
|
||||
// record is left alone (it points at something outside our mesh, e.g.
|
||||
// a non-peer upstream).
|
||||
type PeerConnectivity interface {
|
||||
IsConnectedByIP(ip string) (known, connected bool)
|
||||
}
|
||||
|
||||
type Resolver struct {
|
||||
mu sync.RWMutex
|
||||
records map[dns.Question][]dns.RR
|
||||
@@ -33,6 +46,11 @@ type Resolver struct {
|
||||
// zones maps zone domain -> NonAuthoritative (true = non-authoritative, user-created zone)
|
||||
zones map[domain.Domain]bool
|
||||
resolver resolver
|
||||
// peerConn, when non-nil, is consulted on every A/AAAA answer to
|
||||
// drop records pointing at disconnected peers. nil disables the
|
||||
// filter and preserves the legacy "return whatever is registered"
|
||||
// behaviour for callers that never wire a status source.
|
||||
peerConn PeerConnectivity
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
@@ -49,6 +67,15 @@ func NewResolver() *Resolver {
|
||||
}
|
||||
}
|
||||
|
||||
// SetPeerConnectivity wires the per-IP connectivity check used to filter
|
||||
// out A/AAAA answers pointing at disconnected peers. Pass nil to disable.
|
||||
// Safe to call multiple times; the latest value wins.
|
||||
func (d *Resolver) SetPeerConnectivity(p PeerConnectivity) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.peerConn = p
|
||||
}
|
||||
|
||||
func (d *Resolver) MatchSubdomains() bool {
|
||||
return true
|
||||
}
|
||||
@@ -76,8 +103,6 @@ func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
logger := log.WithFields(log.Fields{
|
||||
@@ -97,6 +122,7 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
replyMessage.RecursionAvailable = true
|
||||
|
||||
result := d.lookupRecords(logger, question)
|
||||
result.records = d.filterDisconnectedPeerAnswers(logger, question, result.records)
|
||||
replyMessage.Authoritative = !result.hasExternalData
|
||||
replyMessage.Answer = result.records
|
||||
replyMessage.Rcode = d.determineRcode(question, result)
|
||||
@@ -438,6 +464,78 @@ func (d *Resolver) logDNSError(logger *log.Entry, hostname string, qtype uint16,
|
||||
}
|
||||
}
|
||||
|
||||
// filterDisconnectedPeerAnswers drops A/AAAA records whose RDATA matches
|
||||
// a known but disconnected peer. The synthesized private-service zones
|
||||
// emit one A record per connected proxy peer in a cluster; when a peer
|
||||
// goes offline, the server-side refresh removes the record from the
|
||||
// next netmap, but the client may still hold the previous netmap for a
|
||||
// short window. This filter is the local belt to that braces — even on
|
||||
// the stale netmap, the resolver hides the offline target.
|
||||
//
|
||||
// Records pointing at unknown IPs (outside the local peerstore, e.g.
|
||||
// non-mesh upstreams) are never dropped. Non-A/AAAA records pass
|
||||
// through untouched.
|
||||
//
|
||||
// Escape hatch: if filtering would leave the answer empty AND at least
|
||||
// one record was filtered, the original list is returned. Better to
|
||||
// hand the client a record that may not respond than NXDOMAIN it
|
||||
// completely when every proxy peer is offline (the upstream may still
|
||||
// be reachable some other way, or the peerstore may be stale).
|
||||
func (d *Resolver) filterDisconnectedPeerAnswers(logger *log.Entry, question dns.Question, records []dns.RR) []dns.RR {
|
||||
if len(records) == 0 {
|
||||
return records
|
||||
}
|
||||
d.mu.RLock()
|
||||
checker := d.peerConn
|
||||
d.mu.RUnlock()
|
||||
if checker == nil {
|
||||
return records
|
||||
}
|
||||
|
||||
kept := make([]dns.RR, 0, len(records))
|
||||
var dropped int
|
||||
for _, rr := range records {
|
||||
ip := extractRecordIP(rr)
|
||||
if ip == "" {
|
||||
kept = append(kept, rr)
|
||||
continue
|
||||
}
|
||||
known, connected := checker.IsConnectedByIP(ip)
|
||||
if known && !connected {
|
||||
dropped++
|
||||
continue
|
||||
}
|
||||
kept = append(kept, rr)
|
||||
}
|
||||
if dropped == 0 {
|
||||
return records
|
||||
}
|
||||
if len(kept) == 0 {
|
||||
logger.Debugf("all %d answers for %s point at disconnected peers; returning the original list", dropped, question.Name)
|
||||
return records
|
||||
}
|
||||
logger.Tracef("dropped %d disconnected-peer answer(s) for %s, returning %d", dropped, question.Name, len(kept))
|
||||
return kept
|
||||
}
|
||||
|
||||
// extractRecordIP returns the dotted-decimal / colon-hex IP carried by
|
||||
// an A or AAAA record, or "" for any other record type.
|
||||
func extractRecordIP(rr dns.RR) string {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
if r.A == nil {
|
||||
return ""
|
||||
}
|
||||
return r.A.String()
|
||||
case *dns.AAAA:
|
||||
if r.AAAA == nil {
|
||||
return ""
|
||||
}
|
||||
return r.AAAA.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Update replaces all zones and their records
|
||||
func (d *Resolver) Update(customZones []nbdns.CustomZone) {
|
||||
d.mu.Lock()
|
||||
|
||||
@@ -30,6 +30,21 @@ func (m *mockResolver) LookupNetIP(ctx context.Context, network, host string) ([
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// mockPeerConnectivity returns canned (known, connected) results per IP.
|
||||
// Used by the disconnected-peer filter tests below. IPs not in the map
|
||||
// are reported as unknown so the filter leaves them alone.
|
||||
type mockPeerConnectivity struct {
|
||||
byIP map[string]struct{ known, connected bool }
|
||||
}
|
||||
|
||||
func (m mockPeerConnectivity) IsConnectedByIP(ip string) (known, connected bool) {
|
||||
v, ok := m.byIP[ip]
|
||||
if !ok {
|
||||
return false, false
|
||||
}
|
||||
return v.known, v.connected
|
||||
}
|
||||
|
||||
func TestLocalResolver_ServeDNS(t *testing.T) {
|
||||
recordA := nbdns.SimpleRecord{
|
||||
Name: "peera.netbird.cloud.",
|
||||
@@ -2652,3 +2667,114 @@ func BenchmarkIsInManagedZone_ManyZones(b *testing.B) {
|
||||
resolver.isInManagedZone(qname)
|
||||
}
|
||||
}
|
||||
|
||||
// TestLocalResolver_FilterDisconnectedPeerAnswers verifies the
|
||||
// connectivity-aware filtering layered on top of lookupRecords:
|
||||
// when an A record's IP belongs to a known peer that's disconnected,
|
||||
// the record is dropped from the answer. Records for unknown IPs pass
|
||||
// through. If filtering would empty the answer entirely and at least
|
||||
// one record was dropped, the original list is restored (escape hatch
|
||||
// for the "all proxies offline" case).
|
||||
func TestLocalResolver_FilterDisconnectedPeerAnswers(t *testing.T) {
|
||||
zone := "svc.cluster.netbird."
|
||||
connectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.10",
|
||||
}
|
||||
disconnectedRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "100.64.0.11",
|
||||
}
|
||||
unknownRec := nbdns.SimpleRecord{
|
||||
Name: zone,
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: 5,
|
||||
RData: "203.0.113.5",
|
||||
}
|
||||
|
||||
type ipState struct{ known, connected bool }
|
||||
tests := []struct {
|
||||
name string
|
||||
records []nbdns.SimpleRecord
|
||||
connByIP map[string]ipState
|
||||
wantInOrder []string
|
||||
}{
|
||||
{
|
||||
name: "drops disconnected peer, keeps connected",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: true},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "unknown IPs pass through untouched",
|
||||
records: []nbdns.SimpleRecord{unknownRec, disconnectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"203.0.113.5"},
|
||||
},
|
||||
{
|
||||
name: "all disconnected falls back to original list",
|
||||
records: []nbdns.SimpleRecord{disconnectedRec, connectedRec},
|
||||
connByIP: map[string]ipState{
|
||||
"100.64.0.10": {known: true, connected: false},
|
||||
"100.64.0.11": {known: true, connected: false},
|
||||
},
|
||||
wantInOrder: []string{"100.64.0.11", "100.64.0.10"},
|
||||
},
|
||||
{
|
||||
name: "no checker wired returns all records",
|
||||
records: []nbdns.SimpleRecord{connectedRec, disconnectedRec},
|
||||
connByIP: nil,
|
||||
wantInOrder: []string{"100.64.0.10", "100.64.0.11"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
if tc.connByIP != nil {
|
||||
cm := mockPeerConnectivity{byIP: make(map[string]struct{ known, connected bool }, len(tc.connByIP))}
|
||||
for ip, st := range tc.connByIP {
|
||||
cm.byIP[ip] = struct{ known, connected bool }{st.known, st.connected}
|
||||
}
|
||||
resolver.SetPeerConnectivity(cm)
|
||||
}
|
||||
resolver.Update([]nbdns.CustomZone{{
|
||||
Domain: strings.TrimSuffix(zone, "."),
|
||||
Records: tc.records,
|
||||
NonAuthoritative: true,
|
||||
}})
|
||||
|
||||
var got *dns.Msg
|
||||
writer := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
got = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
req := new(dns.Msg).SetQuestion(zone, dns.TypeA)
|
||||
resolver.ServeDNS(writer, req)
|
||||
|
||||
require.NotNil(t, got, "resolver must produce a response")
|
||||
require.Len(t, got.Answer, len(tc.wantInOrder),
|
||||
"answer count must match expected: %v", tc.wantInOrder)
|
||||
for i, want := range tc.wantInOrder {
|
||||
a, ok := got.Answer[i].(*dns.A)
|
||||
require.True(t, ok, "answer[%d] must be an A record", i)
|
||||
assert.Equal(t, want, a.A.String(),
|
||||
"answer[%d] expected %s got %s", i, want, a.A.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -70,10 +71,6 @@ func (m *MockServer) SearchDomains() []string {
|
||||
return make([]string, 0)
|
||||
}
|
||||
|
||||
// ProbeAvailability mocks implementation of ProbeAvailability from the Server interface
|
||||
func (m *MockServer) ProbeAvailability() {
|
||||
}
|
||||
|
||||
func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
if m.UpdateServerConfigFunc != nil {
|
||||
return m.UpdateServerConfigFunc(domains)
|
||||
@@ -85,8 +82,8 @@ func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
||||
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||
// SetRouteSources mock implementation of SetRouteSources from Server interface
|
||||
func (m *MockServer) SetRouteSources(selected, active func() route.HAMap) {
|
||||
// Mock implementation - no-op
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -32,6 +33,15 @@ const (
|
||||
networkManagerDbusDeviceGetAppliedConnectionMethod = networkManagerDbusDeviceInterface + ".GetAppliedConnection"
|
||||
networkManagerDbusDeviceReapplyMethod = networkManagerDbusDeviceInterface + ".Reapply"
|
||||
networkManagerDbusDeviceDeleteMethod = networkManagerDbusDeviceInterface + ".Delete"
|
||||
networkManagerDbusDeviceIp4ConfigProperty = networkManagerDbusDeviceInterface + ".Ip4Config"
|
||||
networkManagerDbusDeviceIp6ConfigProperty = networkManagerDbusDeviceInterface + ".Ip6Config"
|
||||
networkManagerDbusDeviceIfaceProperty = networkManagerDbusDeviceInterface + ".Interface"
|
||||
networkManagerDbusGetDevicesMethod = networkManagerDest + ".GetDevices"
|
||||
networkManagerDbusIp4ConfigInterface = "org.freedesktop.NetworkManager.IP4Config"
|
||||
networkManagerDbusIp6ConfigInterface = "org.freedesktop.NetworkManager.IP6Config"
|
||||
networkManagerDbusIp4ConfigNameserverDataProperty = networkManagerDbusIp4ConfigInterface + ".NameserverData"
|
||||
networkManagerDbusIp4ConfigNameserversProperty = networkManagerDbusIp4ConfigInterface + ".Nameservers"
|
||||
networkManagerDbusIp6ConfigNameserversProperty = networkManagerDbusIp6ConfigInterface + ".Nameservers"
|
||||
networkManagerDbusDefaultBehaviorFlag networkManagerConfigBehavior = 0
|
||||
networkManagerDbusIPv4Key = "ipv4"
|
||||
networkManagerDbusIPv6Key = "ipv6"
|
||||
@@ -51,9 +61,10 @@ var supportedNetworkManagerVersionConstraints = []string{
|
||||
}
|
||||
|
||||
type networkManagerDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
routingAll bool
|
||||
ifaceName string
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
@@ -92,10 +103,200 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
|
||||
|
||||
log.Debugf("got network manager dbus Link Object: %s from net interface %s", s, wgInterface)
|
||||
|
||||
return &networkManagerDbusConfigurator{
|
||||
c := &networkManagerDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
origNameservers, err := c.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from NetworkManager: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from non-WG NetworkManager devices; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from non-WG NetworkManager devices: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
c.origNameservers = origNameservers
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads DNS servers from every NM device's
|
||||
// IP4Config / IP6Config except our WG device.
|
||||
func (n *networkManagerDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
devices, err := networkManagerListDevices()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list devices: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
for _, dev := range devices {
|
||||
if dev == n.dbusLinkObject {
|
||||
continue
|
||||
}
|
||||
ifaceName := readNetworkManagerDeviceInterface(dev)
|
||||
for _, addr := range readNetworkManagerDeviceDNS(dev) {
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
continue
|
||||
}
|
||||
// IP6Config.Nameservers is a byte slice without zone info;
|
||||
// reattach the device's interface name so a captured fe80::…
|
||||
// stays routable.
|
||||
if addr.IsLinkLocalUnicast() && ifaceName != "" {
|
||||
addr = addr.WithZone(ifaceName)
|
||||
}
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func readNetworkManagerDeviceInterface(devicePath dbus.ObjectPath) string {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(networkManagerDbusDeviceIfaceProperty)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
s, _ := v.Value().(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func networkManagerListDevices() ([]dbus.ObjectPath, error) {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dbus NetworkManager: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var devs []dbus.ObjectPath
|
||||
if err := obj.Call(networkManagerDbusGetDevicesMethod, dbusDefaultFlag).Store(&devs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devs, nil
|
||||
}
|
||||
|
||||
func readNetworkManagerDeviceDNS(devicePath dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, devicePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
var out []netip.Addr
|
||||
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp4ConfigProperty); path != "" {
|
||||
out = append(out, readIPv4ConfigDNS(path)...)
|
||||
}
|
||||
if path := readNetworkManagerConfigPath(obj, networkManagerDbusDeviceIp6ConfigProperty); path != "" {
|
||||
out = append(out, readIPv6ConfigDNS(path)...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readNetworkManagerConfigPath(obj dbus.BusObject, property string) dbus.ObjectPath {
|
||||
v, err := obj.GetProperty(property)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
path, ok := v.Value().(dbus.ObjectPath)
|
||||
if !ok || path == "/" {
|
||||
return ""
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func readIPv4ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
|
||||
// NameserverData (NM 1.13+) carries strings; older NMs only expose the
|
||||
// legacy uint32 Nameservers property.
|
||||
if out := readIPv4NameserverData(obj); len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
return readIPv4LegacyNameservers(obj)
|
||||
}
|
||||
|
||||
func readIPv4NameserverData(obj dbus.BusObject) []netip.Addr {
|
||||
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserverDataProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, ok := v.Value().([]map[string]dbus.Variant)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var out []netip.Addr
|
||||
for _, entry := range entries {
|
||||
addrVar, ok := entry["address"]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
s, ok := addrVar.Value().(string)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if a, err := netip.ParseAddr(s); err == nil {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readIPv4LegacyNameservers(obj dbus.BusObject) []netip.Addr {
|
||||
v, err := obj.GetProperty(networkManagerDbusIp4ConfigNameserversProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := v.Value().([]uint32)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]netip.Addr, 0, len(raw))
|
||||
for _, n := range raw {
|
||||
var b [4]byte
|
||||
binary.LittleEndian.PutUint32(b[:], n)
|
||||
out = append(out, netip.AddrFrom4(b))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func readIPv6ConfigDNS(path dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(networkManagerDest, path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(networkManagerDbusIp6ConfigNameserversProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := v.Value().([][]byte)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]netip.Addr, 0, len(raw))
|
||||
for _, b := range raw {
|
||||
if a, ok := netip.AddrFromSlice(b); ok {
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(n.origNameservers)
|
||||
}
|
||||
|
||||
func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
package dns
|
||||
|
||||
func (s *DefaultServer) initialize() (manager hostManager, err error) {
|
||||
return newHostManager()
|
||||
return newHostManager(s.hostsDNSHolder)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
@@ -31,8 +32,10 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
@@ -101,16 +104,17 @@ func init() {
|
||||
formatter.SetTextFormatter(log.StandardLogger())
|
||||
}
|
||||
|
||||
func generateDummyHandler(domain string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
|
||||
var srvs []netip.AddrPort
|
||||
for _, srv := range servers {
|
||||
srvs = append(srvs, srv.AddrPort())
|
||||
}
|
||||
return &upstreamResolverBase{
|
||||
domain: domain,
|
||||
upstreamServers: srvs,
|
||||
cancel: func() {},
|
||||
u := &upstreamResolverBase{
|
||||
domain: domain.Domain(d),
|
||||
cancel: func() {},
|
||||
}
|
||||
u.addRace(srvs)
|
||||
return u
|
||||
}
|
||||
|
||||
func TestUpdateDNSServer(t *testing.T) {
|
||||
@@ -653,74 +657,8 @@ func TestDNSServerStartStop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
|
||||
hostManager := &mockHostConfigurator{}
|
||||
server := DefaultServer{
|
||||
ctx: context.Background(),
|
||||
service: NewServiceViaMemory(&mocWGIface{}),
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: hostManager,
|
||||
currentConfig: HostDNSConfig{
|
||||
Domains: []DomainConfig{
|
||||
{false, "domain0", false},
|
||||
{false, "domain1", false},
|
||||
{false, "domain2", false},
|
||||
},
|
||||
},
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
}
|
||||
|
||||
var domainsUpdate string
|
||||
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
|
||||
domains := []string{}
|
||||
for _, item := range config.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
domainsUpdate = strings.Join(domains, ",")
|
||||
return nil
|
||||
}
|
||||
|
||||
deactivate, reactivate := server.upstreamCallbacks(&nbdns.NameServerGroup{
|
||||
Domains: []string{"domain1"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.0.0"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
}, nil, 0)
|
||||
|
||||
deactivate(nil)
|
||||
expected := "domain0,domain2"
|
||||
domains := []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got := strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, got)
|
||||
}
|
||||
|
||||
reactivate()
|
||||
expected = "domain0,domain1,domain2"
|
||||
domains = []string{}
|
||||
for _, item := range server.currentConfig.Domains {
|
||||
if item.Disabled {
|
||||
continue
|
||||
}
|
||||
domains = append(domains, item.Domain)
|
||||
}
|
||||
got = strings.Join(domains, ",")
|
||||
if expected != got {
|
||||
t.Errorf("expected domains list: %q, got %q", expected, domainsUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -748,6 +686,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -841,6 +780,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
skipUnlessAndroid(t)
|
||||
wgIFace, err := createWgInterfaceWithBind(t)
|
||||
if err != nil {
|
||||
t.Fatal("failed to initialize wg interface")
|
||||
@@ -913,6 +853,18 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// skipUnlessAndroid marks tests that exercise the mobile-permanent DNS path,
|
||||
// which only matches a real production setup on android (NewDefaultServerPermanentUpstream
|
||||
// + androidHostManager). On non-android the desktop host manager replaces it
|
||||
// during Initialize and the assertion stops making sense. Skipped here until we
|
||||
// have an android CI runner.
|
||||
func skipUnlessAndroid(t *testing.T) {
|
||||
t.Helper()
|
||||
if runtime.GOOS != "android" {
|
||||
t.Skip("requires android runner; mobile-permanent path doesn't match production on this OS")
|
||||
}
|
||||
}
|
||||
|
||||
func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
t.Helper()
|
||||
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
|
||||
@@ -1065,7 +1017,6 @@ type mockHandler struct {
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
@@ -2085,6 +2036,598 @@ func TestLocalResolverPriorityConstants(t *testing.T) {
|
||||
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
|
||||
}
|
||||
|
||||
// TestBuildUpstreamHandler_MergesGroupsPerDomain verifies that multiple
|
||||
// admin-defined nameserver groups targeting the same domain collapse into a
|
||||
// single handler with each group preserved as a sequential inner list.
|
||||
func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgInterface,
|
||||
service: service,
|
||||
localResolver: local.NewResolver(),
|
||||
handlerChain: NewHandlerChain(),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
}
|
||||
|
||||
groups := []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("192.0.2.1"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
{
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("192.0.2.2"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
{IP: netip.MustParseAddr("192.0.2.3"), NSType: nbdns.UDPNameServerType, Port: 53},
|
||||
},
|
||||
Domains: []string{"example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
muxUpdates, err := server.buildUpstreamHandlerUpdate(groups)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, muxUpdates, 1, "same-domain groups should merge into one handler")
|
||||
assert.Equal(t, "example.com", muxUpdates[0].domain)
|
||||
assert.Equal(t, PriorityUpstream, muxUpdates[0].priority)
|
||||
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
require.Len(t, handler.upstreamServers, 2, "handler should have two groups")
|
||||
assert.Equal(t, upstreamRace{netip.MustParseAddrPort("192.0.2.1:53")}, handler.upstreamServers[0])
|
||||
assert.Equal(t, upstreamRace{
|
||||
netip.MustParseAddrPort("192.0.2.2:53"),
|
||||
netip.MustParseAddrPort("192.0.2.3:53"),
|
||||
}, handler.upstreamServers[1])
|
||||
}
|
||||
|
||||
// TestEvaluateNSGroupHealth covers the records-only verdict. The gate
|
||||
// (overlay route selected-but-no-active-peer) is intentionally NOT an
|
||||
// input to the evaluator anymore: the verdict drives the Enabled flag,
|
||||
// which must always reflect what we actually observed. Gate-aware event
|
||||
// suppression is tested separately in the projection test.
|
||||
//
|
||||
// Matrix per upstream: {no record, fresh Ok, fresh Fail, stale Fail,
|
||||
// stale Ok, Ok newer than Fail, Fail newer than Ok}.
|
||||
// Group verdict: any fresh-working → Healthy; any fresh-broken with no
|
||||
// fresh-working → Unhealthy; otherwise Undecided.
|
||||
func TestEvaluateNSGroupHealth(t *testing.T) {
|
||||
now := time.Now()
|
||||
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
recentOk := UpstreamHealth{LastOk: now.Add(-2 * time.Second)}
|
||||
recentFail := UpstreamHealth{LastFail: now.Add(-1 * time.Second), LastErr: "timeout"}
|
||||
staleOk := UpstreamHealth{LastOk: now.Add(-10 * time.Minute)}
|
||||
staleFail := UpstreamHealth{LastFail: now.Add(-10 * time.Minute), LastErr: "timeout"}
|
||||
okThenFail := UpstreamHealth{
|
||||
LastOk: now.Add(-10 * time.Second),
|
||||
LastFail: now.Add(-1 * time.Second),
|
||||
LastErr: "timeout",
|
||||
}
|
||||
failThenOk := UpstreamHealth{
|
||||
LastOk: now.Add(-1 * time.Second),
|
||||
LastFail: now.Add(-10 * time.Second),
|
||||
LastErr: "timeout",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
health map[netip.AddrPort]UpstreamHealth
|
||||
servers []netip.AddrPort
|
||||
wantVerdict nsGroupVerdict
|
||||
wantErrSubst string
|
||||
}{
|
||||
{
|
||||
name: "no record, undecided",
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "fresh success, healthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: recentOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "fresh failure, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: recentFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "only stale success, undecided",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: staleOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "only stale failure, undecided",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: staleFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUndecided,
|
||||
},
|
||||
{
|
||||
name: "both fresh, fail newer, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: okThenFail},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "both fresh, ok newer, healthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{a: failThenOk},
|
||||
servers: []netip.AddrPort{a},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "two upstreams, one success wins",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: recentFail,
|
||||
b: recentOk,
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictHealthy,
|
||||
},
|
||||
{
|
||||
name: "two upstreams, one fail one unseen, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: recentFail,
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "timeout",
|
||||
},
|
||||
{
|
||||
name: "two upstreams, all recent failures, unhealthy",
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
a: {LastFail: now.Add(-5 * time.Second), LastErr: "timeout"},
|
||||
b: {LastFail: now.Add(-1 * time.Second), LastErr: "SERVFAIL"},
|
||||
},
|
||||
servers: []netip.AddrPort{a, b},
|
||||
wantVerdict: nsVerdictUnhealthy,
|
||||
wantErrSubst: "SERVFAIL",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
verdict, err := evaluateNSGroupHealth(tc.health, tc.servers, now)
|
||||
assert.Equal(t, tc.wantVerdict, verdict, "verdict mismatch")
|
||||
if tc.wantErrSubst != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tc.wantErrSubst)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed
|
||||
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
|
||||
// without spinning up real handlers.
|
||||
type healthStubHandler struct {
|
||||
health map[netip.AddrPort]UpstreamHealth
|
||||
}
|
||||
|
||||
func (h *healthStubHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (h *healthStubHandler) Stop() {}
|
||||
func (h *healthStubHandler) ID() types.HandlerID { return "health-stub" }
|
||||
func (h *healthStubHandler) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
return h.health
|
||||
}
|
||||
|
||||
// TestProjection_SteadyStateIsSilent guards against duplicate events:
|
||||
// while a group stays Unhealthy tick after tick, only the first
|
||||
// Unhealthy transition may emit. Same for staying Healthy.
|
||||
func TestProjection_SteadyStateIsSilent(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "first fail emits warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.tick()
|
||||
fx.expectNoEvent("staying unhealthy must not re-emit")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "recovery on transition")
|
||||
|
||||
fx.tick()
|
||||
fx.tick()
|
||||
fx.expectNoEvent("staying healthy must not re-emit")
|
||||
}
|
||||
|
||||
// projTestFixture is the common setup for the projection tests: a
|
||||
// single-upstream group whose route classification the test can flip by
|
||||
// assigning to selected/active. Callers drive failures/successes by
|
||||
// mutating stub.health and calling refreshHealth.
|
||||
type projTestFixture struct {
|
||||
t *testing.T
|
||||
recorder *peer.Status
|
||||
events <-chan *proto.SystemEvent
|
||||
server *DefaultServer
|
||||
stub *healthStubHandler
|
||||
group *nbdns.NameServerGroup
|
||||
srv netip.AddrPort
|
||||
selected route.HAMap
|
||||
active route.HAMap
|
||||
}
|
||||
|
||||
func newProjTestFixture(t *testing.T) *projTestFixture {
|
||||
t.Helper()
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
|
||||
srv := netip.MustParseAddrPort("100.64.0.1:53")
|
||||
fx := &projTestFixture{
|
||||
t: t,
|
||||
recorder: recorder,
|
||||
events: sub.Events(),
|
||||
stub: &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{}},
|
||||
srv: srv,
|
||||
group: &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
},
|
||||
}
|
||||
fx.server = &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return fx.selected },
|
||||
activeRoutes: func() route.HAMap { return fx.active },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
}
|
||||
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}
|
||||
|
||||
fx.server.mux.Lock()
|
||||
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
|
||||
fx.server.mux.Unlock()
|
||||
return fx
|
||||
}
|
||||
|
||||
func (f *projTestFixture) setHealth(h UpstreamHealth) {
|
||||
f.stub.health = map[netip.AddrPort]UpstreamHealth{f.srv: h}
|
||||
}
|
||||
|
||||
func (f *projTestFixture) tick() []peer.NSGroupState {
|
||||
f.server.refreshHealth()
|
||||
return f.recorder.GetDNSStates()
|
||||
}
|
||||
|
||||
func (f *projTestFixture) expectNoEvent(why string) {
|
||||
f.t.Helper()
|
||||
select {
|
||||
case evt := <-f.events:
|
||||
f.t.Fatalf("unexpected event (%s): %+v", why, evt)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
func (f *projTestFixture) expectEvent(substr, why string) *proto.SystemEvent {
|
||||
f.t.Helper()
|
||||
select {
|
||||
case evt := <-f.events:
|
||||
assert.Contains(f.t, evt.Message, substr, why)
|
||||
return evt
|
||||
case <-time.After(time.Second):
|
||||
f.t.Fatalf("expected event (%s) with %q", why, substr)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var overlayNetForTest = netip.MustParsePrefix("100.64.0.0/16")
|
||||
var overlayMapForTest = route.HAMap{"overlay": {{Network: overlayNetForTest}}}
|
||||
|
||||
// TestProjection_PublicFailEmitsImmediately covers rule 1: an upstream
|
||||
// that is not inside any selected route (public DNS) fires the warning
|
||||
// on the first Unhealthy tick, no grace period.
|
||||
func TestProjection_PublicFailEmitsImmediately(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled)
|
||||
fx.expectEvent("unreachable", "public DNS failure")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayConnectedFailEmitsImmediately covers rule 2:
|
||||
// the upstream is inside a selected route AND the route has a Connected
|
||||
// peer. Tunnel is up, failure is real, emit immediately.
|
||||
func TestProjection_OverlayConnectedFailEmitsImmediately(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled)
|
||||
fx.expectEvent("unreachable", "overlay + connected failure")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayNotConnectedDelaysWarning covers rule 3: the
|
||||
// upstream is routed but no peer is Connected (Connecting/Idle/missing).
|
||||
// First tick: Unhealthy display, no warning. After the grace window
|
||||
// elapses with no recovery, the warning fires.
|
||||
func TestProjection_OverlayNotConnectedDelaysWarning(t *testing.T) {
|
||||
grace := 50 * time.Millisecond
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = grace
|
||||
fx.selected = overlayMapForTest
|
||||
// active stays nil: routed but not connected.
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.False(t, states[0].Enabled, "display must reflect failure even during grace window")
|
||||
fx.expectNoEvent("first fail tick within grace window")
|
||||
|
||||
time.Sleep(grace + 10*time.Millisecond)
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "warning after grace window")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayAddrNoRouteDelaysWarning covers an upstream
|
||||
// whose address is inside the WireGuard overlay range but is not
|
||||
// covered by any selected route (peer-to-peer DNS without an explicit
|
||||
// route). Until a peer reports Connected for that address, startup
|
||||
// failures must be held just like the routed case.
|
||||
func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
|
||||
overlayPeer := netip.MustParseAddrPort("100.66.100.5:53")
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: &mocWGIface{},
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: 50 * time.Millisecond,
|
||||
}
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: overlayPeer.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlayPeer.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
|
||||
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-sub.Events():
|
||||
t.Fatalf("unexpected event during grace window: %+v", evt)
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
stub.health = map[netip.AddrPort]UpstreamHealth{overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}}
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-sub.Events():
|
||||
assert.Contains(t, evt.Message, "unreachable")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected warning after grace window")
|
||||
}
|
||||
}
|
||||
|
||||
// TestProjection_StopClearsHealthState verifies that Stop wipes the
|
||||
// per-group projection state so a subsequent Start doesn't inherit
|
||||
// sticky flags (notably everHealthy) that would bypass the grace
|
||||
// window during the next peer handshake.
|
||||
func TestProjection_StopClearsHealthState(t *testing.T) {
|
||||
wgIface := &mocWGIface{}
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
wgInterface: wgIface,
|
||||
service: NewServiceViaMemory(wgIface),
|
||||
hostManager: &noopHostConfigurator{},
|
||||
extraDomains: map[domain.Domain]int{},
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
statusRecorder: peer.NewRecorder("mgm"),
|
||||
selectedRoutes: func() route.HAMap { return nil },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: defaultWarningDelayBase,
|
||||
currentConfigHash: ^uint64(0),
|
||||
}
|
||||
server.ctx, server.ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
srv := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
|
||||
}
|
||||
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
server.healthProjectMu.Lock()
|
||||
p, ok := server.nsGroupProj[generateGroupKey(group)]
|
||||
server.healthProjectMu.Unlock()
|
||||
require.True(t, ok, "projection state should exist after tick")
|
||||
require.True(t, p.everHealthy, "tick with success must set everHealthy")
|
||||
|
||||
server.Stop()
|
||||
|
||||
server.healthProjectMu.Lock()
|
||||
cleared := server.nsGroupProj == nil
|
||||
server.healthProjectMu.Unlock()
|
||||
assert.True(t, cleared, "Stop must clear nsGroupProj")
|
||||
}
|
||||
|
||||
// TestProjection_OverlayRecoversDuringGrace covers the happy path of
|
||||
// rule 3: startup failures while the peer is handshaking, then the peer
|
||||
// comes up and a query succeeds before the grace window elapses. No
|
||||
// warning should ever have fired, and no recovery either.
|
||||
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.server.warningDelayBase = 200 * time.Millisecond
|
||||
fx.selected = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectNoEvent("fail within grace, warning suppressed")
|
||||
|
||||
fx.active = overlayMapForTest
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.True(t, states[0].Enabled)
|
||||
fx.expectNoEvent("recovery without prior warning must not emit")
|
||||
}
|
||||
|
||||
// TestProjection_RecoveryOnlyAfterWarning enforces the invariant the
|
||||
// whole design leans on: recovery events only appear when a warning
|
||||
// event was actually emitted for the current streak. A Healthy verdict
|
||||
// without a prior warning is silent, so the user never sees "recovered"
|
||||
// out of thin air.
|
||||
func TestProjection_RecoveryOnlyAfterWarning(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
states := fx.tick()
|
||||
require.Len(t, states, 1)
|
||||
assert.True(t, states[0].Enabled)
|
||||
fx.expectNoEvent("first healthy tick should not recover anything")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "public fail emits immediately")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "recovery follows real warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "second cycle warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "second cycle recovery")
|
||||
}
|
||||
|
||||
// TestProjection_EverHealthyOverridesDelay covers rule 4: once a group
|
||||
// has ever been Healthy, subsequent failures skip the grace window even
|
||||
// if classification says "routed + not connected". The system has
|
||||
// proved it can work, so any new failure is real.
|
||||
func TestProjection_EverHealthyOverridesDelay(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
// Large base so any emission must come from the everHealthy bypass, not elapsed time.
|
||||
fx.server.warningDelayBase = time.Hour
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
// Establish "ever healthy".
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectNoEvent("first healthy tick")
|
||||
|
||||
// Peer drops. Query fails. Routed + not connected → normally grace,
|
||||
// but everHealthy flag bypasses it.
|
||||
fx.active = nil
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "failure after ever-healthy must be immediate")
|
||||
}
|
||||
|
||||
// TestProjection_ReconnectBlipEmitsPair covers the explicit tradeoff
|
||||
// from the design discussion: once a group has been healthy, a brief
|
||||
// reconnect that produces a failing tick will fire warning + recovery.
|
||||
// This is by design: user-visible blips are accurate signal, not noise.
|
||||
func TestProjection_ReconnectBlipEmitsPair(t *testing.T) {
|
||||
fx := newProjTestFixture(t)
|
||||
fx.selected = overlayMapForTest
|
||||
fx.active = overlayMapForTest
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastFail: time.Now(), LastErr: "timeout"})
|
||||
fx.tick()
|
||||
fx.expectEvent("unreachable", "blip warning")
|
||||
|
||||
fx.setHealth(UpstreamHealth{LastOk: time.Now()})
|
||||
fx.tick()
|
||||
fx.expectEvent("recovered", "blip recovery")
|
||||
}
|
||||
|
||||
// TestProjection_MixedGroupEmitsImmediately covers the multi-upstream
|
||||
// rule: a group with at least one public upstream is in the "immediate"
|
||||
// category regardless of the other upstreams' routing, because the
|
||||
// public one has no peer-startup excuse. Prevents public-DNS failures
|
||||
// from being hidden behind a routed sibling.
|
||||
func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
|
||||
recorder := peer.NewRecorder("mgm")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
t.Cleanup(func() { recorder.UnsubscribeFromEvents(sub) })
|
||||
events := sub.Events()
|
||||
|
||||
public := netip.MustParseAddrPort("8.8.8.8:53")
|
||||
overlay := netip.MustParseAddrPort("100.64.0.1:53")
|
||||
overlayMap := route.HAMap{"overlay": {{Network: netip.MustParsePrefix("100.64.0.0/16")}}}
|
||||
|
||||
server := &DefaultServer{
|
||||
ctx: context.Background(),
|
||||
statusRecorder: recorder,
|
||||
dnsMuxMap: make(registeredHandlerMap),
|
||||
selectedRoutes: func() route.HAMap { return overlayMap },
|
||||
activeRoutes: func() route.HAMap { return nil },
|
||||
warningDelayBase: time.Hour,
|
||||
}
|
||||
group := &nbdns.NameServerGroup{
|
||||
Domains: []string{"example.com"},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: public.Addr(), NSType: nbdns.UDPNameServerType, Port: int(public.Port())},
|
||||
{IP: overlay.Addr(), NSType: nbdns.UDPNameServerType, Port: int(overlay.Port())},
|
||||
},
|
||||
}
|
||||
stub := &healthStubHandler{
|
||||
health: map[netip.AddrPort]UpstreamHealth{
|
||||
public: {LastFail: time.Now(), LastErr: "servfail"},
|
||||
overlay: {LastFail: time.Now(), LastErr: "timeout"},
|
||||
},
|
||||
}
|
||||
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream}
|
||||
|
||||
server.mux.Lock()
|
||||
server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
|
||||
server.mux.Unlock()
|
||||
server.refreshHealth()
|
||||
|
||||
select {
|
||||
case evt := <-events:
|
||||
assert.Contains(t, evt.Message, "unreachable")
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected immediate warning because group contains a public upstream")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSLoopPrevention(t *testing.T) {
|
||||
wgInterface := &mocWGIface{}
|
||||
service := NewServiceViaMemory(wgInterface)
|
||||
@@ -2183,17 +2726,18 @@ func TestDNSLoopPrevention(t *testing.T) {
|
||||
|
||||
if tt.expectedHandlers > 0 {
|
||||
handler := muxUpdates[0].handler.(*upstreamResolver)
|
||||
assert.Len(t, handler.upstreamServers, len(tt.expectedServers))
|
||||
flat := handler.flatUpstreams()
|
||||
assert.Len(t, flat, len(tt.expectedServers))
|
||||
|
||||
if tt.shouldFilterOwnIP {
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
for _, upstream := range flat {
|
||||
assert.NotEqual(t, dnsServerIP, upstream.Addr())
|
||||
}
|
||||
}
|
||||
|
||||
for _, expected := range tt.expectedServers {
|
||||
found := false
|
||||
for _, upstream := range handler.upstreamServers {
|
||||
for _, upstream := range flat {
|
||||
if upstream.Addr() == expected {
|
||||
found = true
|
||||
break
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/godbus/dbus/v5"
|
||||
@@ -40,10 +41,17 @@ const (
|
||||
)
|
||||
|
||||
type systemdDbusConfigurator struct {
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
ifaceName string
|
||||
dbusLinkObject dbus.ObjectPath
|
||||
ifaceName string
|
||||
wgIndex int
|
||||
origNameservers []netip.Addr
|
||||
}
|
||||
|
||||
const (
|
||||
systemdDbusLinkDNSProperty = systemdDbusLinkInterface + ".DNS"
|
||||
systemdDbusLinkDefaultRouteProperty = systemdDbusLinkInterface + ".DefaultRoute"
|
||||
)
|
||||
|
||||
// the types below are based on dbus specification, each field is mapped to a dbus type
|
||||
// see https://dbus.freedesktop.org/doc/dbus-specification.html#basic-types for more details on dbus types
|
||||
// see https://www.freedesktop.org/software/systemd/man/org.freedesktop.resolve1.html on resolve1 input types
|
||||
@@ -79,10 +87,145 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
|
||||
|
||||
log.Debugf("got dbus Link interface: %s from net interface %s and index %d", s, iface.Name, iface.Index)
|
||||
|
||||
return &systemdDbusConfigurator{
|
||||
c := &systemdDbusConfigurator{
|
||||
dbusLinkObject: dbus.ObjectPath(s),
|
||||
ifaceName: wgInterface,
|
||||
}, nil
|
||||
wgIndex: iface.Index,
|
||||
}
|
||||
|
||||
origNameservers, err := c.captureOriginalNameservers()
|
||||
switch {
|
||||
case err != nil:
|
||||
log.Warnf("capture original nameservers from systemd-resolved: %v", err)
|
||||
case len(origNameservers) == 0:
|
||||
log.Warnf("no original nameservers captured from systemd-resolved default-route links; DNS fallback will be empty")
|
||||
default:
|
||||
log.Debugf("captured %d original nameservers from systemd-resolved default-route links: %v", len(origNameservers), origNameservers)
|
||||
}
|
||||
c.origNameservers = origNameservers
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// captureOriginalNameservers reads per-link DNS from systemd-resolved for
|
||||
// every default-route link except our own WG link. Non-default-route links
|
||||
// (VPNs, docker bridges) are skipped because their upstreams wouldn't
|
||||
// actually serve host queries.
|
||||
func (s *systemdDbusConfigurator) captureOriginalNameservers() ([]netip.Addr, error) {
|
||||
ifaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list interfaces: %w", err)
|
||||
}
|
||||
|
||||
seen := make(map[netip.Addr]struct{})
|
||||
var out []netip.Addr
|
||||
for _, iface := range ifaces {
|
||||
if !s.isCandidateLink(iface) {
|
||||
continue
|
||||
}
|
||||
linkPath, err := getSystemdLinkPath(iface.Index)
|
||||
if err != nil || !isSystemdLinkDefaultRoute(linkPath) {
|
||||
continue
|
||||
}
|
||||
for _, addr := range readSystemdLinkDNS(linkPath) {
|
||||
addr = normalizeSystemdAddr(addr, iface.Name)
|
||||
if !addr.IsValid() {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[addr]; dup {
|
||||
continue
|
||||
}
|
||||
seen[addr] = struct{}{}
|
||||
out = append(out, addr)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) isCandidateLink(iface net.Interface) bool {
|
||||
if iface.Index == s.wgIndex {
|
||||
return false
|
||||
}
|
||||
if iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagUp == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// normalizeSystemdAddr unmaps v4-mapped-v6, drops unspecified, and reattaches
|
||||
// the link's iface name as zone for link-local v6 (Link.DNS strips it).
|
||||
// Returns the zero Addr to signal "skip this entry".
|
||||
func normalizeSystemdAddr(addr netip.Addr, ifaceName string) netip.Addr {
|
||||
addr = addr.Unmap()
|
||||
if !addr.IsValid() || addr.IsUnspecified() {
|
||||
return netip.Addr{}
|
||||
}
|
||||
if addr.IsLinkLocalUnicast() {
|
||||
return addr.WithZone(ifaceName)
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
||||
func getSystemdLinkPath(ifIndex int) (dbus.ObjectPath, error) {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("dbus resolve1: %w", err)
|
||||
}
|
||||
defer closeConn()
|
||||
var p string
|
||||
if err := obj.Call(systemdDbusGetLinkMethod, dbusDefaultFlag, int32(ifIndex)).Store(&p); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dbus.ObjectPath(p), nil
|
||||
}
|
||||
|
||||
func isSystemdLinkDefaultRoute(linkPath dbus.ObjectPath) bool {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(systemdDbusLinkDefaultRouteProperty)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
b, ok := v.Value().(bool)
|
||||
return ok && b
|
||||
}
|
||||
|
||||
func readSystemdLinkDNS(linkPath dbus.ObjectPath) []netip.Addr {
|
||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, linkPath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
defer closeConn()
|
||||
v, err := obj.GetProperty(systemdDbusLinkDNSProperty)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
entries, ok := v.Value().([][]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
var out []netip.Addr
|
||||
for _, entry := range entries {
|
||||
if len(entry) < 2 {
|
||||
continue
|
||||
}
|
||||
raw, ok := entry[1].([]byte)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
addr, ok := netip.AddrFromSlice(raw)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemdDbusConfigurator) supportCustomPort() bool {
|
||||
|
||||
@@ -1,3 +1,32 @@
|
||||
// Package dns implements the client-side DNS stack: listener/service on the
|
||||
// peer's tunnel address, handler chain that routes questions by domain and
|
||||
// priority, and upstream resolvers that forward what remains to configured
|
||||
// nameservers.
|
||||
//
|
||||
// # Upstream resolution and the race model
|
||||
//
|
||||
// When two or more nameserver groups target the same domain, DefaultServer
|
||||
// merges them into one upstream handler whose state is:
|
||||
//
|
||||
// upstreamResolverBase
|
||||
// └── upstreamServers []upstreamRace // one entry per source NS group
|
||||
// └── []netip.AddrPort // primary, fallback, ...
|
||||
//
|
||||
// Each source nameserver group contributes one upstreamRace. Within a race
|
||||
// upstreams are tried in order: the next is used only on failure (timeout,
|
||||
// SERVFAIL, REFUSED, no response). NXDOMAIN is a valid answer and stops
|
||||
// the walk. When more than one race exists, ServeDNS fans out one
|
||||
// goroutine per race and returns the first valid answer, cancelling the
|
||||
// rest. A handler with a single race skips the fan-out.
|
||||
//
|
||||
// # Health projection
|
||||
//
|
||||
// Query outcomes are recorded per-upstream in UpstreamHealth. The server
|
||||
// periodically merges these snapshots across handlers and projects them
|
||||
// into peer.NSGroupState. There is no active probing: a group is marked
|
||||
// unhealthy only when every seen upstream has a recent failure and none
|
||||
// has a recent success. Healthy→unhealthy fires a single
|
||||
// SystemEvent_WARNING; steady-state refreshes do not duplicate it.
|
||||
package dns
|
||||
|
||||
import (
|
||||
@@ -11,11 +40,8 @@ import (
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
@@ -25,7 +51,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
||||
"github.com/netbirdio/netbird/client/internal/dns/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
var currentMTU uint16 = iface.DefaultMTU
|
||||
@@ -67,15 +94,17 @@ const (
|
||||
// Set longer than UpstreamTimeout to ensure context timeout takes precedence
|
||||
ClientTimeout = 5 * time.Second
|
||||
|
||||
reactivatePeriod = 30 * time.Second
|
||||
probeTimeout = 2 * time.Second
|
||||
|
||||
// ipv6HeaderSize + udpHeaderSize, used to derive the maximum DNS UDP
|
||||
// payload from the tunnel MTU.
|
||||
ipUDPHeaderSize = 60 + 8
|
||||
)
|
||||
|
||||
const testRecord = "com."
|
||||
// raceMaxTotalTimeout caps the combined time spent walking all upstreams
|
||||
// within one race, so a slow primary can't eat the whole race budget.
|
||||
raceMaxTotalTimeout = 5 * time.Second
|
||||
// raceMinPerUpstreamTimeout is the floor applied when dividing
|
||||
// raceMaxTotalTimeout across upstreams within a race.
|
||||
raceMinPerUpstreamTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
const (
|
||||
protoUDP = "udp"
|
||||
@@ -84,6 +113,69 @@ const (
|
||||
|
||||
type dnsProtocolKey struct{}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type UpstreamResolver interface {
|
||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
// upstreamRace is an ordered list of upstreams derived from one configured
|
||||
// nameserver group. Order matters: the first upstream is tried first, the
|
||||
// second only on failure, and so on. Multiple upstreamRace values coexist
|
||||
// inside one resolver when overlapping nameserver groups target the same
|
||||
// domain; those races run in parallel and the first valid answer wins.
|
||||
type upstreamRace []netip.AddrPort
|
||||
|
||||
// UpstreamHealth is the last query-path outcome for a single upstream,
|
||||
// consumed by nameserver-group status projection.
|
||||
type UpstreamHealth struct {
|
||||
LastOk time.Time
|
||||
LastFail time.Time
|
||||
LastErr string
|
||||
}
|
||||
|
||||
type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []upstreamRace
|
||||
domain domain.Domain
|
||||
upstreamTimeout time.Duration
|
||||
|
||||
healthMu sync.RWMutex
|
||||
health map[netip.AddrPort]*UpstreamHealth
|
||||
|
||||
statusRecorder *peer.Status
|
||||
// selectedRoutes returns the current set of client routes the admin
|
||||
// has enabled. Called lazily from the query hot path when an upstream
|
||||
// might need a tunnel-bound client (iOS) and from health projection.
|
||||
selectedRoutes func() route.HAMap
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
type raceResult struct {
|
||||
msg *dns.Msg
|
||||
upstream netip.AddrPort
|
||||
protocol string
|
||||
ede string
|
||||
failures []upstreamFailure
|
||||
}
|
||||
|
||||
// contextWithDNSProtocol stores the inbound DNS protocol ("udp" or "tcp") in context.
|
||||
func contextWithDNSProtocol(ctx context.Context, network string) context.Context {
|
||||
return context.WithValue(ctx, dnsProtocolKey{}, network)
|
||||
@@ -100,16 +192,8 @@ func dnsProtocolFromContext(ctx context.Context) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type upstreamProtocolKey struct{}
|
||||
|
||||
// upstreamProtocolResult holds the protocol used for the upstream exchange.
|
||||
// Stored as a pointer in context so the exchange function can set it.
|
||||
type upstreamProtocolResult struct {
|
||||
protocol string
|
||||
}
|
||||
|
||||
// contextWithupstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithupstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
// contextWithUpstreamProtocolResult stores a mutable result holder in the context.
|
||||
func contextWithUpstreamProtocolResult(ctx context.Context) (context.Context, *upstreamProtocolResult) {
|
||||
r := &upstreamProtocolResult{}
|
||||
return context.WithValue(ctx, upstreamProtocolKey{}, r), r
|
||||
}
|
||||
@@ -124,67 +208,37 @@ func setUpstreamProtocol(ctx context.Context, protocol string) {
|
||||
}
|
||||
}
|
||||
|
||||
type upstreamClient interface {
|
||||
exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type UpstreamResolver interface {
|
||||
serveDNS(r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
upstreamExchange(upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error)
|
||||
}
|
||||
|
||||
type upstreamResolverBase struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
upstreamClient upstreamClient
|
||||
upstreamServers []netip.AddrPort
|
||||
domain string
|
||||
disabled bool
|
||||
successCount atomic.Int32
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
wg sync.WaitGroup
|
||||
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
statusRecorder *peer.Status
|
||||
routeMatch func(netip.Addr) bool
|
||||
}
|
||||
|
||||
type upstreamFailure struct {
|
||||
upstream netip.AddrPort
|
||||
reason string
|
||||
}
|
||||
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, domain string) *upstreamResolverBase {
|
||||
func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d domain.Domain) *upstreamResolverBase {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
return &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: domain,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: reactivatePeriod,
|
||||
statusRecorder: statusRecorder,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
domain: d,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
statusRecorder: statusRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the upstream resolver
|
||||
func (u *upstreamResolverBase) String() string {
|
||||
return fmt.Sprintf("Upstream %s", u.upstreamServers)
|
||||
return fmt.Sprintf("Upstream %s", u.flatUpstreams())
|
||||
}
|
||||
|
||||
// ID returns the unique handler ID
|
||||
// ID returns the unique handler ID. Race groupings and within-race
|
||||
// ordering are both part of the identity: [[A,B]] and [[A],[B]] query
|
||||
// the same servers but with different semantics (serial fallback vs
|
||||
// parallel race), so their handlers must not collide.
|
||||
func (u *upstreamResolverBase) ID() types.HandlerID {
|
||||
servers := slices.Clone(u.upstreamServers)
|
||||
slices.SortFunc(servers, func(a, b netip.AddrPort) int { return a.Compare(b) })
|
||||
|
||||
hash := sha256.New()
|
||||
hash.Write([]byte(u.domain + ":"))
|
||||
for _, s := range servers {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
hash.Write([]byte(u.domain.PunycodeString() + ":"))
|
||||
for _, race := range u.upstreamServers {
|
||||
hash.Write([]byte("["))
|
||||
for _, s := range race {
|
||||
hash.Write([]byte(s.String()))
|
||||
hash.Write([]byte("|"))
|
||||
}
|
||||
hash.Write([]byte("]"))
|
||||
}
|
||||
return types.HandlerID("upstream-" + hex.EncodeToString(hash.Sum(nil)[:8]))
|
||||
}
|
||||
@@ -194,13 +248,31 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.flatUpstreams())
|
||||
u.cancel()
|
||||
}
|
||||
|
||||
u.mutex.Lock()
|
||||
u.wg.Wait()
|
||||
u.mutex.Unlock()
|
||||
// flatUpstreams is for logging and ID hashing only, not for dispatch.
|
||||
func (u *upstreamResolverBase) flatUpstreams() []netip.AddrPort {
|
||||
var out []netip.AddrPort
|
||||
for _, g := range u.upstreamServers {
|
||||
out = append(out, g...)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// setSelectedRoutes swaps the accessor used to classify overlay-routed
|
||||
// upstreams. Called when route sources are wired after the handler was
|
||||
// built (permanent / iOS constructors).
|
||||
func (u *upstreamResolverBase) setSelectedRoutes(selected func() route.HAMap) {
|
||||
u.selectedRoutes = selected
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) addRace(servers []netip.AddrPort) {
|
||||
if len(servers) == 0 {
|
||||
return
|
||||
}
|
||||
u.upstreamServers = append(u.upstreamServers, slices.Clone(servers))
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
@@ -242,82 +314,201 @@ func (u *upstreamResolverBase) prepareRequest(r *dns.Msg) {
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryUpstreamServers(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(u.upstreamServers) > 1 {
|
||||
maxTotal := 5 * time.Second
|
||||
minPerUpstream := 2 * time.Second
|
||||
scaledTimeout := maxTotal / time.Duration(len(u.upstreamServers))
|
||||
if scaledTimeout > minPerUpstream {
|
||||
timeout = scaledTimeout
|
||||
} else {
|
||||
timeout = minPerUpstream
|
||||
}
|
||||
groups := u.upstreamServers
|
||||
switch len(groups) {
|
||||
case 0:
|
||||
return false, nil
|
||||
case 1:
|
||||
return u.tryOnlyRace(ctx, w, r, groups[0], logger)
|
||||
default:
|
||||
return u.raceAll(ctx, w, r, groups, logger)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) tryOnlyRace(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, group upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
res := u.tryRace(ctx, r, group)
|
||||
if res.msg == nil {
|
||||
return false, res.failures
|
||||
}
|
||||
if res.ede != "" {
|
||||
resutil.SetMeta(w, "ede", res.ede)
|
||||
}
|
||||
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||
return true, res.failures
|
||||
}
|
||||
|
||||
// raceAll runs one worker per group in parallel, taking the first valid
|
||||
// answer and cancelling the rest.
|
||||
func (u *upstreamResolverBase) raceAll(ctx context.Context, w dns.ResponseWriter, r *dns.Msg, groups []upstreamRace, logger *log.Entry) (bool, []upstreamFailure) {
|
||||
raceCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Buffer sized to len(groups) so workers never block on send, even
|
||||
// after the coordinator has returned.
|
||||
results := make(chan raceResult, len(groups))
|
||||
for _, g := range groups {
|
||||
// tryRace clones the request per attempt, so workers never share
|
||||
// a *dns.Msg and concurrent EDNS0 mutations can't race.
|
||||
go func(g upstreamRace) {
|
||||
results <- u.tryRace(raceCtx, r, g)
|
||||
}(g)
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if failure := u.queryUpstream(ctx, w, r, upstream, timeout, logger); failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
} else {
|
||||
return true, failures
|
||||
for range groups {
|
||||
select {
|
||||
case res := <-results:
|
||||
failures = append(failures, res.failures...)
|
||||
if res.msg != nil {
|
||||
if res.ede != "" {
|
||||
resutil.SetMeta(w, "ede", res.ede)
|
||||
}
|
||||
u.writeSuccessResponse(w, res.msg, res.upstream, r.Question[0].Name, res.protocol, logger)
|
||||
return true, failures
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return false, failures
|
||||
}
|
||||
}
|
||||
return false, failures
|
||||
}
|
||||
|
||||
// queryUpstream queries a single upstream server. Returns nil on success, or failure info to try next upstream.
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.ResponseWriter, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration, logger *log.Entry) *upstreamFailure {
|
||||
var rm *dns.Msg
|
||||
var t time.Duration
|
||||
var err error
|
||||
func (u *upstreamResolverBase) tryRace(ctx context.Context, r *dns.Msg, group upstreamRace) raceResult {
|
||||
timeout := u.upstreamTimeout
|
||||
if len(group) > 1 {
|
||||
// Cap the whole walk at raceMaxTotalTimeout: per-upstream timeouts
|
||||
// still honor raceMinPerUpstreamTimeout as a floor for correctness
|
||||
// on slow links, but the outer context ensures the combined walk
|
||||
// cannot exceed the cap regardless of group size.
|
||||
timeout = max(raceMaxTotalTimeout/time.Duration(len(group)), raceMinPerUpstreamTimeout)
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, raceMaxTotalTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
var failures []upstreamFailure
|
||||
for _, upstream := range group {
|
||||
if ctx.Err() != nil {
|
||||
return raceResult{failures: failures}
|
||||
}
|
||||
// Clone the request per attempt: the exchange path mutates EDNS0
|
||||
// options in-place, so reusing the same *dns.Msg across sequential
|
||||
// upstreams would carry those mutations (e.g. a reduced UDP size)
|
||||
// into the next attempt.
|
||||
res, failure := u.queryUpstream(ctx, r.Copy(), upstream, timeout)
|
||||
if failure != nil {
|
||||
failures = append(failures, *failure)
|
||||
continue
|
||||
}
|
||||
res.failures = failures
|
||||
return res
|
||||
}
|
||||
return raceResult{failures: failures}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.Msg, upstream netip.AddrPort, timeout time.Duration) (raceResult, *upstreamFailure) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto := contextWithUpstreamProtocolResult(ctx)
|
||||
|
||||
// Advertise EDNS0 so the upstream may include Extended DNS Errors
|
||||
// (RFC 8914) in failure responses; we use those to short-circuit
|
||||
// failover for definitive answers like DNSSEC validation failures.
|
||||
// Operate on a copy so the inbound request is unchanged: a client that
|
||||
// did not advertise EDNS0 must not see an OPT in the response.
|
||||
// The caller already passed a per-attempt copy, so we can mutate r
|
||||
// directly; hadEdns reflects the original client request's state and
|
||||
// controls whether we strip the OPT from the response.
|
||||
hadEdns := r.IsEdns0() != nil
|
||||
reqUp := r
|
||||
if !hadEdns {
|
||||
reqUp = r.Copy()
|
||||
reqUp.SetEdns0(upstreamUDPSize(), false)
|
||||
r.SetEdns0(upstreamUDPSize(), false)
|
||||
}
|
||||
|
||||
var startTime time.Time
|
||||
var upstreamProto *upstreamProtocolResult
|
||||
func() {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, timeout)
|
||||
defer cancel()
|
||||
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
|
||||
startTime = time.Now()
|
||||
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
|
||||
}()
|
||||
startTime := time.Now()
|
||||
rm, _, err := u.upstreamClient.exchange(ctx, upstream.String(), r)
|
||||
|
||||
if err != nil {
|
||||
return u.handleUpstreamError(err, upstream, startTime)
|
||||
// A parent cancellation (e.g., another race won and the coordinator
|
||||
// cancelled the losers) is not an upstream failure. Check both the
|
||||
// error chain and the parent context: a transport may surface the
|
||||
// cancellation as a read/deadline error rather than context.Canceled.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(parentCtx.Err(), context.Canceled) {
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "canceled"}
|
||||
}
|
||||
failure := u.handleUpstreamError(err, upstream, startTime)
|
||||
u.markUpstreamFail(upstream, failure.reason)
|
||||
return raceResult{}, failure
|
||||
}
|
||||
|
||||
if rm == nil || !rm.Response {
|
||||
return &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
u.markUpstreamFail(upstream, "no response")
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
|
||||
}
|
||||
|
||||
proto := ""
|
||||
if upstreamProto != nil {
|
||||
proto = upstreamProto.protocol
|
||||
}
|
||||
|
||||
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
|
||||
if code, ok := nonRetryableEDE(rm); ok {
|
||||
resutil.SetMeta(w, "ede", edeName(code))
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
|
||||
}
|
||||
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
|
||||
reason := dns.RcodeToString[rm.Rcode]
|
||||
u.markUpstreamFail(upstream, reason)
|
||||
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
if !hadEdns {
|
||||
stripOPT(rm)
|
||||
}
|
||||
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
|
||||
return nil
|
||||
|
||||
u.markUpstreamOk(upstream)
|
||||
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
|
||||
}
|
||||
|
||||
// healthEntry returns the mutable health record for addr, lazily creating
|
||||
// the map and the entry. Caller must hold u.healthMu.
|
||||
func (u *upstreamResolverBase) healthEntry(addr netip.AddrPort) *UpstreamHealth {
|
||||
if u.health == nil {
|
||||
u.health = make(map[netip.AddrPort]*UpstreamHealth)
|
||||
}
|
||||
h := u.health[addr]
|
||||
if h == nil {
|
||||
h = &UpstreamHealth{}
|
||||
u.health[addr] = h
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) markUpstreamOk(addr netip.AddrPort) {
|
||||
u.healthMu.Lock()
|
||||
defer u.healthMu.Unlock()
|
||||
h := u.healthEntry(addr)
|
||||
h.LastOk = time.Now()
|
||||
h.LastFail = time.Time{}
|
||||
h.LastErr = ""
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) markUpstreamFail(addr netip.AddrPort, reason string) {
|
||||
u.healthMu.Lock()
|
||||
defer u.healthMu.Unlock()
|
||||
h := u.healthEntry(addr)
|
||||
h.LastFail = time.Now()
|
||||
h.LastErr = reason
|
||||
}
|
||||
|
||||
// UpstreamHealth returns a snapshot of per-upstream query outcomes.
|
||||
func (u *upstreamResolverBase) UpstreamHealth() map[netip.AddrPort]UpstreamHealth {
|
||||
u.healthMu.RLock()
|
||||
defer u.healthMu.RUnlock()
|
||||
out := make(map[netip.AddrPort]UpstreamHealth, len(u.health))
|
||||
for k, v := range u.health {
|
||||
out[k] = *v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// upstreamUDPSize returns the EDNS0 UDP buffer size we advertise to upstreams,
|
||||
@@ -358,12 +549,23 @@ func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.Add
|
||||
return &upstreamFailure{upstream: upstream, reason: reason}
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, t time.Duration, upstreamProto *upstreamProtocolResult, logger *log.Entry) bool {
|
||||
u.successCount.Add(1)
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dns.Msg, upstream netip.AddrPort, domain string, proto string, logger *log.Entry) {
|
||||
resutil.SetMeta(w, "upstream", upstream.String())
|
||||
if upstreamProto != nil && upstreamProto.protocol != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", upstreamProto.protocol)
|
||||
if proto != "" {
|
||||
resutil.SetMeta(w, "upstream_protocol", proto)
|
||||
}
|
||||
|
||||
// Clear Zero bit from external responses to prevent upstream servers from
|
||||
@@ -372,14 +574,11 @@ func (u *upstreamResolverBase) writeSuccessResponse(w dns.ResponseWriter, rm *dn
|
||||
|
||||
if err := w.WriteMsg(rm); err != nil {
|
||||
logger.Errorf("failed to write DNS response for question domain=%s: %s", domain, err)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) logUpstreamFailures(domain string, failures []upstreamFailure, succeeded bool, logger *log.Entry) {
|
||||
totalUpstreams := len(u.upstreamServers)
|
||||
totalUpstreams := len(u.flatUpstreams())
|
||||
failedCount := len(failures)
|
||||
failureSummary := formatFailures(failures)
|
||||
|
||||
@@ -434,119 +633,6 @@ func edeName(code uint16) string {
|
||||
return fmt.Sprintf("EDE %d", code)
|
||||
}
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query
|
||||
if u.successCount.Load() > 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var success bool
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errs *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
wg.Add(1)
|
||||
go func(upstream netip.AddrPort) {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
mu.Lock()
|
||||
errs = multierror.Append(errs, err)
|
||||
mu.Unlock()
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
success = true
|
||||
mu.Unlock()
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errs.ErrorOrNil())
|
||||
|
||||
if u.statusRecorder == nil {
|
||||
return
|
||||
}
|
||||
|
||||
u.statusRecorder.PublishEvent(
|
||||
proto.SystemEvent_WARNING,
|
||||
proto.SystemEvent_DNS,
|
||||
"All upstream servers failed (probe failed)",
|
||||
"Unable to reach one or more DNS servers. This might affect your ability to connect to some services.",
|
||||
map[string]string{"upstreams": u.upstreamServersString()},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// waitUntilResponse retries, in an exponential interval, querying the upstream servers until it gets a positive response
|
||||
func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
exponentialBackOff := &backoff.ExponentialBackOff{
|
||||
InitialInterval: 500 * time.Millisecond,
|
||||
RandomizationFactor: 0.5,
|
||||
Multiplier: 1.1,
|
||||
MaxInterval: u.reactivatePeriod,
|
||||
MaxElapsedTime: 0,
|
||||
Stop: backoff.Stop,
|
||||
Clock: backoff.SystemClock,
|
||||
}
|
||||
|
||||
operation := func() error {
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return backoff.Permanent(fmt.Errorf("exiting upstream retry loop for upstreams %s: parent context has been canceled", u.upstreamServersString()))
|
||||
default:
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Tracef("checking connectivity with upstreams %s failed. Retrying in %s", u.upstreamServersString(), exponentialBackOff.NextBackOff())
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.mutex.Lock()
|
||||
u.disabled = false
|
||||
u.mutex.Unlock()
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
//
|
||||
// Copied from k8s.io/apimachinery/pkg/util/net.IsTimeout
|
||||
@@ -558,45 +644,6 @@ func isTimeout(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) disable(err error) {
|
||||
if u.disabled {
|
||||
return
|
||||
}
|
||||
|
||||
log.Warnf("Upstream resolving is Disabled for %v", reactivatePeriod)
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.waitUntilResponse()
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
var servers []string
|
||||
for _, server := range u.upstreamServers {
|
||||
servers = append(servers, server.String())
|
||||
}
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
if externalCtx != nil {
|
||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||
defer stop2()
|
||||
}
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
// clientUDPMaxSize returns the maximum UDP response size the client accepts.
|
||||
func clientUDPMaxSize(r *dns.Msg) int {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
@@ -608,13 +655,10 @@ func clientUDPMaxSize(r *dns.Msg) int {
|
||||
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||
// If the inbound request came over TCP (via context), it skips the UDP attempt.
|
||||
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||
// If the request came in over TCP, go straight to TCP upstream.
|
||||
if dnsProtocolFromContext(ctx) == protoTCP {
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
rm, t, err := tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
rm, t, err := toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
@@ -634,18 +678,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
opt.SetUDPSize(maxUDPPayload)
|
||||
}
|
||||
|
||||
var (
|
||||
rm *dns.Msg
|
||||
t time.Duration
|
||||
err error
|
||||
)
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = client.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
rm, t, err := client.ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||
}
|
||||
@@ -659,15 +692,7 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
// data than the client's buffer, we could truncate locally and skip
|
||||
// the TCP retry.
|
||||
|
||||
tcpClient := *client
|
||||
tcpClient.Net = protoTCP
|
||||
|
||||
if ctx == nil {
|
||||
rm, t, err = tcpClient.Exchange(r, upstream)
|
||||
} else {
|
||||
rm, t, err = tcpClient.ExchangeContext(ctx, r, upstream)
|
||||
}
|
||||
|
||||
rm, t, err = toTCPClient(client).ExchangeContext(ctx, r, upstream)
|
||||
if err != nil {
|
||||
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||
}
|
||||
@@ -681,6 +706,25 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
|
||||
return rm, t, nil
|
||||
}
|
||||
|
||||
// toTCPClient returns a copy of c configured for TCP. If c's Dialer has a
|
||||
// *net.UDPAddr bound as LocalAddr (iOS does this to keep the source IP on
|
||||
// the tunnel interface), it is converted to the equivalent *net.TCPAddr
|
||||
// so net.Dialer doesn't reject the TCP dial with "mismatched local
|
||||
// address type".
|
||||
func toTCPClient(c *dns.Client) *dns.Client {
|
||||
tcp := *c
|
||||
tcp.Net = protoTCP
|
||||
if tcp.Dialer == nil {
|
||||
return &tcp
|
||||
}
|
||||
d := *tcp.Dialer
|
||||
if ua, ok := d.LocalAddr.(*net.UDPAddr); ok {
|
||||
d.LocalAddr = &net.TCPAddr{IP: ua.IP, Port: ua.Port, Zone: ua.Zone}
|
||||
}
|
||||
tcp.Dialer = &d
|
||||
return &tcp
|
||||
}
|
||||
|
||||
// ExchangeWithNetstack performs a DNS exchange using netstack for dialing.
|
||||
// This is needed when netstack is enabled to reach peer IPs through the tunnel.
|
||||
func ExchangeWithNetstack(ctx context.Context, nsNet *netstack.Net, r *dns.Msg, upstream string) (*dns.Msg, error) {
|
||||
@@ -822,15 +866,36 @@ func findPeerForIP(ip netip.Addr, statusRecorder *peer.Status) *peer.State {
|
||||
return bestMatch
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) debugUpstreamTimeout(upstream netip.AddrPort) string {
|
||||
if u.statusRecorder == nil {
|
||||
return ""
|
||||
// haMapRouteCount returns the total number of routes across all HA
|
||||
// groups in the map. route.HAMap is keyed by HAUniqueID with slices of
|
||||
// routes per key, so len(hm) is the number of HA groups, not routes.
|
||||
func haMapRouteCount(hm route.HAMap) int {
|
||||
total := 0
|
||||
for _, routes := range hm {
|
||||
total += len(routes)
|
||||
}
|
||||
|
||||
peerInfo := findPeerForIP(upstream.Addr(), u.statusRecorder)
|
||||
if peerInfo == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("(routes through NetBird peer %s)", FormatPeerStatus(peerInfo))
|
||||
return total
|
||||
}
|
||||
|
||||
// haMapContains checks whether ip is covered by any concrete prefix in
|
||||
// the HA map. haveDynamic is reported separately: dynamic (domain-based)
|
||||
// routes carry a placeholder Network that can't be prefix-checked, so we
|
||||
// can't know at this point whether ip is reached through one. Callers
|
||||
// decide how to interpret the unknown: health projection treats it as
|
||||
// "possibly routed" to avoid emitting false-positive warnings during
|
||||
// startup, while iOS dial selection requires a concrete match before
|
||||
// binding to the tunnel.
|
||||
func haMapContains(hm route.HAMap, ip netip.Addr) (matched, haveDynamic bool) {
|
||||
for _, routes := range hm {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
haveDynamic = true
|
||||
continue
|
||||
}
|
||||
if r.Network.Contains(ip) {
|
||||
return true, haveDynamic
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, haveDynamic
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -26,9 +27,9 @@ func newUpstreamResolver(
|
||||
_ WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
hostsDNSHolder *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
c := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
hostsDNSHolder: hostsDNSHolder,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolver struct {
|
||||
@@ -24,9 +25,9 @@ func newUpstreamResolver(
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolver, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
nonIOS := &upstreamResolver{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
nsNet: wgIface.GetNet(),
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
type upstreamResolverIOS struct {
|
||||
@@ -27,9 +28,9 @@ func newUpstreamResolver(
|
||||
wgIface WGIface,
|
||||
statusRecorder *peer.Status,
|
||||
_ *hostsDNSHolder,
|
||||
domain string,
|
||||
d domain.Domain,
|
||||
) (*upstreamResolverIOS, error) {
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, domain)
|
||||
upstreamResolverBase := newUpstreamResolverBase(ctx, statusRecorder, d)
|
||||
|
||||
ios := &upstreamResolverIOS{
|
||||
upstreamResolverBase: upstreamResolverBase,
|
||||
@@ -62,9 +63,16 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
upstreamIP = upstreamIP.Unmap()
|
||||
}
|
||||
addr := u.wgIface.Address()
|
||||
var routed bool
|
||||
if u.selectedRoutes != nil {
|
||||
// Only a concrete prefix match binds to the tunnel: dialing
|
||||
// through a private client for an upstream we can't prove is
|
||||
// routed would break public resolvers.
|
||||
routed, _ = haMapContains(u.selectedRoutes(), upstreamIP)
|
||||
}
|
||||
needsPrivate := addr.Network.Contains(upstreamIP) ||
|
||||
addr.IPv6Net.Contains(upstreamIP) ||
|
||||
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||
routed
|
||||
if needsPrivate {
|
||||
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||
client, err = GetClientPrivate(u.wgIface, upstreamIP, timeout)
|
||||
@@ -73,8 +81,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
||||
}
|
||||
}
|
||||
|
||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||
return ExchangeWithFallback(nil, client, r, upstream)
|
||||
return ExchangeWithFallback(ctx, client, r, upstream)
|
||||
}
|
||||
|
||||
// GetClientPrivate returns a new DNS client bound to the local IP of the Netbird interface.
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -73,7 +74,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
||||
servers = append(servers, netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port()))
|
||||
}
|
||||
}
|
||||
resolver.upstreamServers = servers
|
||||
resolver.addRace(servers)
|
||||
resolver.upstreamTimeout = testCase.timeout
|
||||
if testCase.cancelCTX {
|
||||
cancel()
|
||||
@@ -132,20 +133,10 @@ func (m *mockNetstackProvider) GetInterfaceGUIDString() (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type mockUpstreamResolver struct {
|
||||
r *dns.Msg
|
||||
rtt time.Duration
|
||||
err error
|
||||
}
|
||||
|
||||
// exchange mock implementation of exchange from upstreamResolver
|
||||
func (c mockUpstreamResolver) exchange(_ context.Context, _ string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
return c.r, c.rtt, c.err
|
||||
}
|
||||
|
||||
type mockUpstreamResponse struct {
|
||||
msg *dns.Msg
|
||||
err error
|
||||
msg *dns.Msg
|
||||
err error
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
type mockUpstreamResolverPerServer struct {
|
||||
@@ -153,63 +144,19 @@ type mockUpstreamResolverPerServer struct {
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func (c mockUpstreamResolverPerServer) exchange(_ context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
if r, ok := c.responses[upstream]; ok {
|
||||
return r.msg, c.rtt, r.err
|
||||
func (c mockUpstreamResolverPerServer) exchange(ctx context.Context, upstream string, _ *dns.Msg) (*dns.Msg, time.Duration, error) {
|
||||
r, ok := c.responses[upstream]
|
||||
if !ok {
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
return nil, c.rtt, fmt.Errorf("no mock response for %s", upstream)
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
mockClient := &mockUpstreamResolver{
|
||||
err: dns.ErrTime,
|
||||
r: new(dns.Msg),
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: context.TODO(),
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
reactivatePeriod: time.Microsecond * 100,
|
||||
}
|
||||
addrPort, _ := netip.ParseAddrPort("0.0.0.0:1") // Use valid port for parsing, test will still fail on connection
|
||||
resolver.upstreamServers = []netip.AddrPort{netip.AddrPortFrom(addrPort.Addr().Unmap(), addrPort.Port())}
|
||||
|
||||
failed := false
|
||||
resolver.deactivate = func(error) {
|
||||
failed = true
|
||||
// After deactivation, make the mock client work again
|
||||
mockClient.err = nil
|
||||
}
|
||||
|
||||
reactivated := false
|
||||
resolver.reactivate = func() {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ProbeAvailability(context.TODO())
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if !resolver.disabled {
|
||||
t.Errorf("resolver should be Disabled")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 200)
|
||||
|
||||
if !reactivated {
|
||||
t.Errorf("expected that resolving was reactivated")
|
||||
return
|
||||
}
|
||||
|
||||
if resolver.disabled {
|
||||
t.Errorf("should be enabled")
|
||||
if r.delay > 0 {
|
||||
select {
|
||||
case <-time.After(r.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, c.rtt, ctx.Err()
|
||||
}
|
||||
}
|
||||
return r.msg, c.rtt, r.err
|
||||
}
|
||||
|
||||
func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
@@ -339,9 +286,9 @@ func TestUpstreamResolver_Failover(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: trackingClient,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{upstream1, upstream2})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
@@ -421,9 +368,9 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamServers: []netip.AddrPort{upstream},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{upstream})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
@@ -440,6 +387,136 @@ func TestUpstreamResolver_SingleUpstreamFailure(t *testing.T) {
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode, "single upstream SERVFAIL should return SERVFAIL")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_RaceAcrossGroups covers two nameserver groups
|
||||
// configured for the same domain, with one broken group. The merge+race
|
||||
// path should answer as fast as the working group and not pay the timeout
|
||||
// of the broken one on every query.
|
||||
func TestUpstreamResolver_RaceAcrossGroups(t *testing.T) {
|
||||
broken := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
working := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
successAnswer := "192.0.2.100"
|
||||
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
// Force the broken upstream to only unblock via timeout /
|
||||
// cancellation so the assertion below can't pass if races
|
||||
// were run serially.
|
||||
broken.String(): {err: timeoutErr, delay: 500 * time.Millisecond},
|
||||
working.String(): {msg: buildMockResponse(dns.RcodeSuccess, successAnswer)},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: 250 * time.Millisecond,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{broken})
|
||||
resolver.addRace([]netip.AddrPort{working})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
inputMSG := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
|
||||
start := time.Now()
|
||||
resolver.ServeDNS(responseWriter, inputMSG)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
require.NotNil(t, responseMSG, "should write a response")
|
||||
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode)
|
||||
require.NotEmpty(t, responseMSG.Answer)
|
||||
assert.Contains(t, responseMSG.Answer[0].String(), successAnswer)
|
||||
// Working group answers in a single RTT; the broken group's
|
||||
// timeout (100ms) must not block the response.
|
||||
assert.Less(t, elapsed, 100*time.Millisecond, "race must not wait for broken group's timeout")
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_AllGroupsFail checks that when every group fails the
|
||||
// resolver returns SERVFAIL rather than leaking a partial response.
|
||||
func TestUpstreamResolver_AllGroupsFail(t *testing.T) {
|
||||
a := netip.MustParseAddrPort("192.0.2.1:53")
|
||||
b := netip.MustParseAddrPort("192.0.2.2:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
a.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
b.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{a})
|
||||
resolver.addRace([]netip.AddrPort{b})
|
||||
|
||||
var responseMSG *dns.Msg
|
||||
responseWriter := &test.MockResponseWriter{
|
||||
WriteMsgFunc: func(m *dns.Msg) error {
|
||||
responseMSG = m
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
require.NotNil(t, responseMSG)
|
||||
assert.Equal(t, dns.RcodeServerFailure, responseMSG.Rcode)
|
||||
}
|
||||
|
||||
// TestUpstreamResolver_HealthTracking verifies that query-path results are
|
||||
// recorded into per-upstream health, which is what projects back to
|
||||
// NSGroupState for status reporting.
|
||||
func TestUpstreamResolver_HealthTracking(t *testing.T) {
|
||||
ok := netip.MustParseAddrPort("192.0.2.10:53")
|
||||
bad := netip.MustParseAddrPort("192.0.2.11:53")
|
||||
|
||||
mockClient := &mockUpstreamResolverPerServer{
|
||||
responses: map[string]mockUpstreamResponse{
|
||||
ok.String(): {msg: buildMockResponse(dns.RcodeSuccess, "192.0.2.100")},
|
||||
bad.String(): {msg: buildMockResponse(dns.RcodeServerFailure, "")},
|
||||
},
|
||||
rtt: time.Millisecond,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: mockClient,
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
resolver.addRace([]netip.AddrPort{ok, bad})
|
||||
|
||||
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
|
||||
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
|
||||
|
||||
health := resolver.UpstreamHealth()
|
||||
require.Contains(t, health, ok)
|
||||
assert.False(t, health[ok].LastOk.IsZero(), "ok upstream should have LastOk set")
|
||||
assert.Empty(t, health[ok].LastErr)
|
||||
|
||||
// bad upstream was never tried because ok answered first; its health
|
||||
// should remain unset.
|
||||
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
|
||||
}
|
||||
|
||||
func TestFormatFailures(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -665,10 +742,10 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
// Verify that a client EDNS0 larger than our MTU-derived limit gets
|
||||
// capped in the outgoing request so the upstream doesn't send a
|
||||
// response larger than our read buffer.
|
||||
var receivedUDPSize uint16
|
||||
var receivedUDPSize atomic.Uint32
|
||||
udpHandler := dns.HandlerFunc(func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if opt := r.IsEdns0(); opt != nil {
|
||||
receivedUDPSize = opt.UDPSize()
|
||||
receivedUDPSize.Store(uint32(opt.UDPSize()))
|
||||
}
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
@@ -699,7 +776,7 @@ func TestExchangeWithFallback_EDNS0Capped(t *testing.T) {
|
||||
require.NotNil(t, rm)
|
||||
|
||||
expectedMax := uint16(currentMTU - ipUDPHeaderSize)
|
||||
assert.Equal(t, expectedMax, receivedUDPSize,
|
||||
assert.Equal(t, expectedMax, uint16(receivedUDPSize.Load()),
|
||||
"upstream should see capped EDNS0, not the client's 4096")
|
||||
}
|
||||
|
||||
@@ -874,7 +951,7 @@ func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
|
||||
resolver := &upstreamResolverBase{
|
||||
ctx: ctx,
|
||||
upstreamClient: tracking,
|
||||
upstreamServers: []netip.AddrPort{upstream1, upstream2},
|
||||
upstreamServers: []upstreamRace{{upstream1, upstream2}},
|
||||
upstreamTimeout: UpstreamTimeout,
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package dnsfwd
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -160,12 +159,13 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
anyV4 := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
dnsRules, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||
}
|
||||
|
||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||
tcpRules, err := m.firewall.AddFilterRule(nil, anyV4, firewall.Network{}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||
}
|
||||
@@ -209,12 +209,12 @@ func (m *Manager) unregisterNetstackServices() {
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range m.fwRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
for _, rule := range m.tcpRules {
|
||||
if err := m.firewall.DeletePeerRule(rule); err != nil {
|
||||
if err := m.firewall.DeleteFilterRule(rule); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -512,16 +512,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||
|
||||
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||
for _, routes := range e.routeManager.GetSelectedClientRoutes() {
|
||||
for _, r := range routes {
|
||||
if r.Network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
e.dnsServer.SetRouteSources(e.routeManager.GetSelectedClientRoutes, e.routeManager.GetActiveClientRoutes)
|
||||
|
||||
if err = e.wgInterfaceCreate(); err != nil {
|
||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||
@@ -649,14 +640,14 @@ func (e *Engine) initFirewall() error {
|
||||
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||
|
||||
// IPv4-only: rosenpass peers connect via AllowedIps[0] which is always v4.
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
if _, err := e.firewall.AddFilterRule(
|
||||
nil,
|
||||
net.IP{0, 0, 0, 0},
|
||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||
firewallManager.Network{},
|
||||
firewallManager.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
firewallManager.ActionAccept,
|
||||
"",
|
||||
); err != nil {
|
||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||
return nil
|
||||
@@ -706,7 +697,7 @@ func (e *Engine) blockLanAccess() {
|
||||
if network.Addr().Is6() {
|
||||
source = v6
|
||||
}
|
||||
if _, err := e.firewall.AddRouteFiltering(
|
||||
if _, err := e.firewall.AddFilterRule(
|
||||
nil,
|
||||
[]netip.Prefix{source},
|
||||
firewallManager.Network{Prefix: network},
|
||||
@@ -1386,9 +1377,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
go e.dnsServer.ProbeAvailability()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1932,7 +1920,7 @@ func (e *Engine) newDnsServer(dnsConfig *nbdns.Config) (dns.Server, error) {
|
||||
return dnsServer, nil
|
||||
|
||||
case "ios":
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.mobileDep.HostDNSAddresses, e.statusRecorder, e.config.DisableDNS)
|
||||
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
|
||||
return dnsServer, nil
|
||||
|
||||
default:
|
||||
@@ -1979,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||
return e.clientMetrics
|
||||
}
|
||||
|
||||
// Performance bundles runtime-adjustable tunnel pool knobs.
|
||||
// See Engine.SetPerformance. Nil fields are ignored.
|
||||
type Performance struct {
|
||||
PreallocatedBuffersPerPool *uint32
|
||||
}
|
||||
|
||||
// SetPerformance applies the given tuning to this engine's live Device.
|
||||
func (e *Engine) SetPerformance(t Performance) error {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
if e.wgInterface == nil {
|
||||
return fmt.Errorf("wg interface not initialized")
|
||||
}
|
||||
dev := e.wgInterface.GetWGDevice()
|
||||
if dev == nil {
|
||||
return fmt.Errorf("wg device not initialized")
|
||||
}
|
||||
if t.PreallocatedBuffersPerPool != nil {
|
||||
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
@@ -66,8 +66,8 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||
@@ -1641,7 +1641,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user