mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-17 21:39:58 +00:00
Compare commits
1 Commits
peer-acl-m
...
client-jso
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2fd1bb0a8 |
45
.github/dependabot.yml
vendored
45
.github/dependabot.yml
vendored
@@ -1,45 +0,0 @@
|
|||||||
version: 2
|
|
||||||
updates:
|
|
||||||
- package-ecosystem: "github-actions"
|
|
||||||
directory: "/"
|
|
||||||
schedule:
|
|
||||||
interval: "daily"
|
|
||||||
open-pull-requests-limit: 15
|
|
||||||
groups:
|
|
||||||
actions:
|
|
||||||
patterns:
|
|
||||||
- "*"
|
|
||||||
ignore:
|
|
||||||
# git-town/action v1.3.x crashes on cyclic PR graphs (self-loop main->main
|
|
||||||
# fork PRs) via its topological-sort visualization. Pinned to v1.2.1 in
|
|
||||||
# git-town.yml; block v1.3.x until upstream tolerates cyclic edges.
|
|
||||||
- dependency-name: "git-town/action"
|
|
||||||
update-types:
|
|
||||||
- "version-update:semver-minor"
|
|
||||||
- "version-update:semver-major"
|
|
||||||
|
|
||||||
- package-ecosystem: "gomod"
|
|
||||||
directories:
|
|
||||||
- "/"
|
|
||||||
schedule:
|
|
||||||
interval: "daily"
|
|
||||||
open-pull-requests-limit: 15
|
|
||||||
groups:
|
|
||||||
aws-sdk:
|
|
||||||
patterns:
|
|
||||||
- "github.com/aws/aws-sdk-go-v2/*"
|
|
||||||
pion:
|
|
||||||
patterns:
|
|
||||||
- "github.com/pion/*"
|
|
||||||
gorm:
|
|
||||||
patterns:
|
|
||||||
- "gorm.io/*"
|
|
||||||
otel:
|
|
||||||
patterns:
|
|
||||||
- "go.opentelemetry.io/*"
|
|
||||||
testcontainers:
|
|
||||||
patterns:
|
|
||||||
- "github.com/testcontainers/testcontainers-go/*"
|
|
||||||
wireguard:
|
|
||||||
patterns:
|
|
||||||
- "golang.zx2c4.com/wireguard*"
|
|
||||||
105
.github/workflows/check-license-dependencies.yml
vendored
105
.github/workflows/check-license-dependencies.yml
vendored
@@ -2,16 +2,16 @@ name: Check License Dependencies
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- "go.mod"
|
- 'go.mod'
|
||||||
- "go.sum"
|
- 'go.sum'
|
||||||
- ".github/workflows/check-license-dependencies.yml"
|
- '.github/workflows/check-license-dependencies.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- "go.mod"
|
- 'go.mod'
|
||||||
- "go.sum"
|
- 'go.sum'
|
||||||
- ".github/workflows/check-license-dependencies.yml"
|
- '.github/workflows/check-license-dependencies.yml'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
check-internal-dependencies:
|
check-internal-dependencies:
|
||||||
@@ -19,10 +19,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- uses: actions/checkout@v4
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -59,57 +56,55 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: 'go.mod'
|
||||||
cache: true
|
cache: true
|
||||||
|
|
||||||
- name: Install go-licenses
|
- name: Install go-licenses
|
||||||
run: go install github.com/google/go-licenses@v1.6.0
|
run: go install github.com/google/go-licenses@v1.6.0
|
||||||
|
|
||||||
- name: Check for GPL/AGPL licensed dependencies
|
- name: Check for GPL/AGPL licensed dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
echo "Checking for GPL/AGPL/LGPL licensed dependencies..."
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
||||||
|
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
||||||
|
|
||||||
|
if [ -n "$COPYLEFT_DEPS" ]; then
|
||||||
|
echo "Found copyleft licensed dependencies:"
|
||||||
|
echo "$COPYLEFT_DEPS"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Check all Go packages for copyleft licenses, excluding internal netbird packages
|
# Filter out dependencies that are only pulled in by internal AGPL packages
|
||||||
COPYLEFT_DEPS=$(go-licenses report ./... 2>/dev/null | grep -E 'GPL|AGPL|LGPL' | grep -v 'github.com/netbirdio/netbird/' || true)
|
INCOMPATIBLE=""
|
||||||
|
while IFS=',' read -r package url license; do
|
||||||
|
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
||||||
|
# Find ALL packages that import this GPL package using go list
|
||||||
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
if [ -n "$COPYLEFT_DEPS" ]; then
|
# Check if any importer is NOT in management/signal/relay
|
||||||
echo "Found copyleft licensed dependencies:"
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||||
echo "$COPYLEFT_DEPS"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Filter out dependencies that are only pulled in by internal AGPL packages
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
INCOMPATIBLE=""
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
while IFS=',' read -r package url license; do
|
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
||||||
if echo "$license" | grep -qE 'GPL-[0-9]|AGPL-[0-9]|LGPL-[0-9]'; then
|
else
|
||||||
# Find ALL packages that import this GPL package using go list
|
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
|
||||||
INCOMPATIBLE="${INCOMPATIBLE}${package},${url},${license}\n"
|
|
||||||
else
|
|
||||||
echo "✓ $package ($license) is only used by internal AGPL packages - OK"
|
|
||||||
fi
|
|
||||||
fi
|
fi
|
||||||
done <<< "$COPYLEFT_DEPS"
|
|
||||||
|
|
||||||
if [ -n "$INCOMPATIBLE" ]; then
|
|
||||||
echo ""
|
|
||||||
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
|
||||||
echo -e "$INCOMPATIBLE"
|
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
fi
|
done <<< "$COPYLEFT_DEPS"
|
||||||
|
|
||||||
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
if [ -n "$INCOMPATIBLE" ]; then
|
||||||
|
echo ""
|
||||||
|
echo "❌ INCOMPATIBLE licenses found that are used by BSD-licensed code:"
|
||||||
|
echo -e "$INCOMPATIBLE"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "✅ All external license dependencies are compatible with BSD-3-Clause"
|
||||||
|
|||||||
2
.github/workflows/docs-ack.yml
vendored
2
.github/workflows/docs-ack.yml
vendored
@@ -83,7 +83,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Verify docs PR exists (and is open or merged)
|
- name: Verify docs PR exists (and is open or merged)
|
||||||
if: steps.validate.outputs.mode == 'added'
|
if: steps.validate.outputs.mode == 'added'
|
||||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
uses: actions/github-script@v7
|
||||||
id: verify
|
id: verify
|
||||||
with:
|
with:
|
||||||
pr_number: ${{ steps.extract.outputs.pr_number }}
|
pr_number: ${{ steps.extract.outputs.pr_number }}
|
||||||
|
|||||||
5
.github/workflows/forum.yml
vendored
5
.github/workflows/forum.yml
vendored
@@ -8,10 +8,11 @@ jobs:
|
|||||||
post:
|
post:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: roots/discourse-topic-github-release-action@557d74ea05b6cc0c47f555c1d5d28a89d904005b # v1.1.0
|
- uses: roots/discourse-topic-github-release-action@main
|
||||||
with:
|
with:
|
||||||
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
discourse-api-key: ${{ secrets.DISCOURSE_RELEASES_API_KEY }}
|
||||||
discourse-base-url: https://forum.netbird.io
|
discourse-base-url: https://forum.netbird.io
|
||||||
discourse-author-username: NetBird
|
discourse-author-username: NetBird
|
||||||
discourse-category: 17
|
discourse-category: 17
|
||||||
discourse-tags: releases
|
discourse-tags:
|
||||||
|
releases
|
||||||
|
|||||||
8
.github/workflows/git-town.yml
vendored
8
.github/workflows/git-town.yml
vendored
@@ -3,7 +3,7 @@ name: Git Town
|
|||||||
on:
|
on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches:
|
branches:
|
||||||
- "**"
|
- '**'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
git-town:
|
git-town:
|
||||||
@@ -15,9 +15,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
- uses: actions/checkout@v4
|
||||||
with:
|
- uses: git-town/action@v1.2.1
|
||||||
persist-credentials: false
|
|
||||||
- uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1
|
|
||||||
with:
|
with:
|
||||||
skip-single-stacks: true
|
skip-single-stacks: true
|
||||||
|
|||||||
16
.github/workflows/golang-test-darwin.yml
vendored
16
.github/workflows/golang-test-darwin.yml
vendored
@@ -16,18 +16,16 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -45,11 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: unit,client
|
|
||||||
|
|||||||
21
.github/workflows/golang-test-freebsd.yml
vendored
21
.github/workflows/golang-test-freebsd.yml
vendored
@@ -15,31 +15,20 @@ jobs:
|
|||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- uses: actions/checkout@v4
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Read Go version from go.mod
|
|
||||||
id: goversion
|
|
||||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
|
||||||
|
|
||||||
- name: Test in FreeBSD
|
- name: Test in FreeBSD
|
||||||
id: test
|
id: test
|
||||||
env:
|
uses: vmactions/freebsd-vm@v1
|
||||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
|
||||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
|
||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
release: "15.0"
|
release: "14.2"
|
||||||
envs: "GO_VERSION"
|
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y curl pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
GO_TARBALL="go1.25.3.freebsd-amd64.tar.gz"
|
||||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
curl -vLO "$GO_URL"
|
curl -vLO "$GO_URL"
|
||||||
tar -C /usr/local -vxzf "$GO_TARBALL"
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
|
|||||||
200
.github/workflows/golang-test-linux.yml
vendored
200
.github/workflows/golang-test-linux.yml
vendored
@@ -18,11 +18,9 @@ jobs:
|
|||||||
management: ${{ steps.filter.outputs.management }}
|
management: ${{ steps.filter.outputs.management }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
|
- uses: dorny/paths-filter@v3
|
||||||
id: filter
|
id: filter
|
||||||
with:
|
with:
|
||||||
filters: |
|
filters: |
|
||||||
@@ -30,7 +28,7 @@ jobs:
|
|||||||
- 'management/**'
|
- 'management/**'
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -38,10 +36,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
id: cache
|
id: cache
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -115,16 +113,14 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["386", "amd64"]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -132,10 +128,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -158,28 +154,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags "devcert integration" -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: unit,client
|
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
needs: [build-cache]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -191,7 +177,7 @@ jobs:
|
|||||||
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
id: cache-restore
|
id: cache-restore
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
@@ -228,7 +214,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags "devcert integration" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -245,12 +231,10 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -262,10 +246,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -284,33 +268,23 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
go test ${{ matrix.raceFlag }} \
|
go test ${{ matrix.raceFlag }} \
|
||||||
-exec 'sudo' -coverprofile=coverage.txt \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: unit,relay
|
|
||||||
|
|
||||||
test_proxy:
|
test_proxy:
|
||||||
name: "Proxy / Unit"
|
name: "Proxy / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["386", "amd64"]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -324,7 +298,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -342,15 +316,7 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
go test -timeout 10m -p 1 -coverprofile=coverage.txt ./proxy/...
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: unit,proxy
|
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
@@ -358,16 +324,14 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["386", "amd64"]
|
arch: [ '386','amd64' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -379,10 +343,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -401,34 +365,24 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
go test \
|
go test \
|
||||||
-exec 'sudo' -coverprofile=coverage.txt \
|
-exec 'sudo' \
|
||||||
-timeout 10m ./signal/... ./shared/signal/...
|
-timeout 10m ./signal/... ./shared/signal/...
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: unit,signal
|
|
||||||
|
|
||||||
test_management:
|
test_management:
|
||||||
name: "Management / Unit"
|
name: "Management / Unit"
|
||||||
needs: [build-cache]
|
needs: [ build-cache ]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["amd64"]
|
arch: [ 'amd64' ]
|
||||||
store: ["sqlite", "postgres", "mysql"]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -436,10 +390,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -456,7 +410,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
@@ -473,31 +427,23 @@ jobs:
|
|||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
go test -tags=devcert -coverprofile=coverage.txt \
|
go test -tags=devcert \
|
||||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
-timeout 20m ./management/... ./shared/management/...
|
-timeout 20m ./management/... ./shared/management/...
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: unit,management
|
|
||||||
|
|
||||||
benchmark:
|
benchmark:
|
||||||
name: "Management / Benchmark"
|
name: "Management / Benchmark"
|
||||||
needs: [build-cache]
|
needs: [ build-cache ]
|
||||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["amd64"]
|
arch: [ 'amd64' ]
|
||||||
store: ["sqlite", "postgres"]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Create Docker network
|
- name: Create Docker network
|
||||||
@@ -528,12 +474,10 @@ jobs:
|
|||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -541,10 +485,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -561,7 +505,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
@@ -585,13 +529,13 @@ jobs:
|
|||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
name: "Management / Benchmark (API)"
|
name: "Management / Benchmark (API)"
|
||||||
needs: [build-cache]
|
needs: [ build-cache ]
|
||||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["amd64"]
|
arch: [ 'amd64' ]
|
||||||
store: ["sqlite", "postgres"]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Create Docker network
|
- name: Create Docker network
|
||||||
@@ -622,12 +566,10 @@ jobs:
|
|||||||
prom/prometheus
|
prom/prometheus
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -635,10 +577,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -655,7 +597,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
@@ -681,22 +623,20 @@ jobs:
|
|||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
name: "Management / Integration"
|
name: "Management / Integration"
|
||||||
needs: [build-cache]
|
needs: [ build-cache ]
|
||||||
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: ["amd64"]
|
arch: [ 'amd64' ]
|
||||||
store: ["sqlite", "postgres"]
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -704,10 +644,10 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache/restore@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -727,14 +667,6 @@ jobs:
|
|||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
CI=true \
|
CI=true \
|
||||||
go test -tags=integration -coverprofile=coverage.txt \
|
go test -tags=integration \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./management/server/http/...
|
-timeout 20m ./management/server/http/...
|
||||||
|
|
||||||
- name: Upload coverage reports to Codecov
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
slug: netbirdio/netbird
|
|
||||||
flags: integration,management
|
|
||||||
|
|||||||
19
.github/workflows/golang-test-windows.yml
vendored
19
.github/workflows/golang-test-windows.yml
vendored
@@ -18,12 +18,10 @@ jobs:
|
|||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
@@ -35,7 +33,7 @@ jobs:
|
|||||||
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
${{ env.cache }}
|
${{ env.cache }}
|
||||||
@@ -46,15 +44,16 @@ jobs:
|
|||||||
${{ runner.os }}-go-
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Download wintun
|
- name: Download wintun
|
||||||
|
uses: carlosperate/download-file-action@v2
|
||||||
id: download-wintun
|
id: download-wintun
|
||||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
|
||||||
with:
|
with:
|
||||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||||
destination: ${{ env.downloadPath }}\wintun.zip
|
file-name: wintun.zip
|
||||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
location: ${{ env.downloadPath }}
|
||||||
|
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||||
|
|
||||||
- name: Decompressing wintun files
|
- name: Decompressing wintun files
|
||||||
run: tar -xvf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||||
|
|
||||||
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
|
||||||
|
|
||||||
|
|||||||
14
.github/workflows/golangci-lint.yml
vendored
14
.github/workflows/golangci-lint.yml
vendored
@@ -15,11 +15,9 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@8f01853be192eb0f849a5c7d721450e7a467c579 # v2.2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
|
||||||
skip: go.mod,go.sum,**/proxy/web/**
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
@@ -40,15 +38,13 @@ jobs:
|
|||||||
timeout-minutes: 15
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: Check for duplicate constants
|
- name: Check for duplicate constants
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: |
|
run: |
|
||||||
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
@@ -56,7 +52,7 @@ jobs:
|
|||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: latest
|
||||||
skip-cache: true
|
skip-cache: true
|
||||||
|
|||||||
4
.github/workflows/install-script-test.yml
vendored
4
.github/workflows/install-script-test.yml
vendored
@@ -22,9 +22,7 @@ jobs:
|
|||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: run install script
|
- name: run install script
|
||||||
env:
|
env:
|
||||||
|
|||||||
18
.github/workflows/mobile-build-validation.yml
vendored
18
.github/workflows/mobile-build-validation.yml
vendored
@@ -16,25 +16,23 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Setup Android SDK
|
- name: Setup Android SDK
|
||||||
uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1
|
uses: android-actions/setup-android@v3
|
||||||
with:
|
with:
|
||||||
cmdline-tools-version: 8512546
|
cmdline-tools-version: 8512546
|
||||||
- name: Setup Java
|
- name: Setup Java
|
||||||
uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
|
uses: actions/setup-java@v4
|
||||||
with:
|
with:
|
||||||
java-version: "11"
|
java-version: "11"
|
||||||
distribution: "adopt"
|
distribution: "adopt"
|
||||||
- name: NDK Cache
|
- name: NDK Cache
|
||||||
id: ndk-cache
|
id: ndk-cache
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: /usr/local/lib/android/sdk/ndk
|
path: /usr/local/lib/android/sdk/ndk
|
||||||
key: ndk-cache-23.1.7779620
|
key: ndk-cache-23.1.7779620
|
||||||
@@ -54,11 +52,9 @@ jobs:
|
|||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: install gomobile
|
- name: install gomobile
|
||||||
|
|||||||
2
.github/workflows/pr-title-check.yml
vendored
2
.github/workflows/pr-title-check.yml
vendored
@@ -9,7 +9,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Validate PR title prefix
|
- name: Validate PR title prefix
|
||||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const title = context.payload.pull_request.title;
|
const title = context.payload.pull_request.title;
|
||||||
|
|||||||
41
.github/workflows/proto-version-check.yml
vendored
41
.github/workflows/proto-version-check.yml
vendored
@@ -10,7 +10,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Check for proto tool version changes
|
- name: Check for proto tool version changes
|
||||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
uses: actions/github-script@v7
|
||||||
with:
|
with:
|
||||||
script: |
|
script: |
|
||||||
const files = await github.paginate(github.rest.pulls.listFiles, {
|
const files = await github.paginate(github.rest.pulls.listFiles, {
|
||||||
@@ -20,30 +20,15 @@ jobs:
|
|||||||
per_page: 100,
|
per_page: 100,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Cover renamed .pb.go files in addition to plain edits.
|
const modifiedPbFiles = files.filter(
|
||||||
// Renamed entries land under the new path with previous_filename
|
f => f.filename.endsWith('.pb.go') && f.status === 'modified'
|
||||||
// pointing at the base-side name, so we read the base content
|
);
|
||||||
// from the old path when present.
|
if (modifiedPbFiles.length === 0) {
|
||||||
const changedPbFiles = files
|
console.log('No modified .pb.go files to check');
|
||||||
.filter(f => (f.status === 'modified' || f.status === 'renamed')
|
|
||||||
&& f.filename.endsWith('.pb.go'))
|
|
||||||
.map(f => ({
|
|
||||||
headPath: f.filename,
|
|
||||||
basePath: f.previous_filename || f.filename,
|
|
||||||
}));
|
|
||||||
if (changedPbFiles.length === 0) {
|
|
||||||
console.log('No modified or renamed .pb.go files to check');
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matches the generator version headers protoc writes at the top
|
const versionPattern = /^\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
|
||||||
// of generated files:
|
|
||||||
// // protoc v3.21.12
|
|
||||||
// // protoc-gen-go v1.26.0
|
|
||||||
// // - protoc-gen-go-grpc v1.6.1 (grpc files prefix with "- ")
|
|
||||||
// The optional "- " prefix and the optional -gen-go / -gen-go-grpc
|
|
||||||
// suffixes keep the *_grpc.pb.go headers in scope.
|
|
||||||
const versionPattern = /^\s*\/\/\s+(?:-\s+)?protoc(?:-gen-go(?:-grpc)?)?\s+v[\d.]+/;
|
|
||||||
const baseSha = context.payload.pull_request.base.sha;
|
const baseSha = context.payload.pull_request.base.sha;
|
||||||
const headSha = context.payload.pull_request.head.sha;
|
const headSha = context.payload.pull_request.head.sha;
|
||||||
|
|
||||||
@@ -70,22 +55,20 @@ jobs:
|
|||||||
}
|
}
|
||||||
|
|
||||||
const violations = [];
|
const violations = [];
|
||||||
for (const file of changedPbFiles) {
|
for (const file of modifiedPbFiles) {
|
||||||
const [base, head] = await Promise.all([
|
const [base, head] = await Promise.all([
|
||||||
getVersionHeader(file.basePath, baseSha),
|
getVersionHeader(file.filename, baseSha),
|
||||||
getVersionHeader(file.headPath, headSha),
|
getVersionHeader(file.filename, headSha),
|
||||||
]);
|
]);
|
||||||
if (!base.ok || !head.ok) {
|
if (!base.ok || !head.ok) {
|
||||||
core.warning(
|
core.warning(
|
||||||
`Skipping ${file.headPath}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
`Skipping ${file.filename}: base=${base.ok ? 'ok' : base.reason}, head=${head.ok ? 'ok' : head.reason}`
|
||||||
);
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
if (base.lines.join('\n') !== head.lines.join('\n')) {
|
||||||
violations.push({
|
violations.push({
|
||||||
file: file.basePath === file.headPath
|
file: file.filename,
|
||||||
? file.headPath
|
|
||||||
: `${file.basePath} → ${file.headPath}`,
|
|
||||||
base: base.lines,
|
base: base.lines,
|
||||||
head: head.lines,
|
head: head.lines,
|
||||||
});
|
});
|
||||||
|
|||||||
172
.github/workflows/release.yml
vendored
172
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.5"
|
SIGN_PIPE_VER: "v0.1.4"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -24,15 +24,13 @@ jobs:
|
|||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Generate FreeBSD port diff
|
- name: Generate FreeBSD port diff
|
||||||
run: bash -x release_files/freebsd-port-diff.sh
|
run: bash release_files/freebsd-port-diff.sh
|
||||||
|
|
||||||
- name: Generate FreeBSD port issue body
|
- name: Generate FreeBSD port issue body
|
||||||
run: bash -x release_files/freebsd-port-issue-body.sh
|
run: bash release_files/freebsd-port-issue-body.sh
|
||||||
|
|
||||||
- name: Check if diff was generated
|
- name: Check if diff was generated
|
||||||
id: check_diff
|
id: check_diff
|
||||||
@@ -53,26 +51,19 @@ jobs:
|
|||||||
echo "Generated files for version: $VERSION"
|
echo "Generated files for version: $VERSION"
|
||||||
cat netbird-*.diff
|
cat netbird-*.diff
|
||||||
|
|
||||||
- name: Read Go version from go.mod
|
|
||||||
id: goversion
|
|
||||||
run: echo "version=$(awk '/^go / {print $2}' go.mod)" >> "$GITHUB_OUTPUT"
|
|
||||||
|
|
||||||
- name: Test FreeBSD port
|
- name: Test FreeBSD port
|
||||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
env:
|
uses: vmactions/freebsd-vm@v1
|
||||||
GO_VERSION: ${{ steps.goversion.outputs.version }}
|
|
||||||
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
|
|
||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
release: "15.0"
|
release: "15.0"
|
||||||
envs: "GO_VERSION"
|
|
||||||
prepare: |
|
prepare: |
|
||||||
# Install required packages
|
# Install required packages
|
||||||
pkg install -y git curl portlint
|
pkg install -y git curl portlint go
|
||||||
|
|
||||||
# Install Go for building
|
# Install Go for building
|
||||||
GO_TARBALL="go${GO_VERSION}.freebsd-amd64.tar.gz"
|
GO_TARBALL="go1.25.5.freebsd-amd64.tar.gz"
|
||||||
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
curl -LO "$GO_URL"
|
curl -LO "$GO_URL"
|
||||||
tar -C /usr/local -xzf "$GO_TARBALL"
|
tar -C /usr/local -xzf "$GO_TARBALL"
|
||||||
@@ -102,19 +93,19 @@ jobs:
|
|||||||
|
|
||||||
# Show patched Makefile
|
# Show patched Makefile
|
||||||
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
version=$(cat security/netbird/Makefile | grep -E '^DISTVERSION=' | awk '{print $NF}')
|
||||||
|
|
||||||
cd /usr/ports/security/netbird
|
cd /usr/ports/security/netbird
|
||||||
export BATCH=yes
|
export BATCH=yes
|
||||||
make package
|
make package
|
||||||
pkg add ./work/pkg/netbird-*.pkg
|
pkg add ./work/pkg/netbird-*.pkg
|
||||||
|
|
||||||
netbird version | grep "$version"
|
netbird version | grep "$version"
|
||||||
|
|
||||||
echo "FreeBSD port test completed successfully!"
|
echo "FreeBSD port test completed successfully!"
|
||||||
|
|
||||||
- name: Upload FreeBSD port files
|
- name: Upload FreeBSD port files
|
||||||
if: steps.check_diff.outputs.diff_exists == 'true'
|
if: steps.check_diff.outputs.diff_exists == 'true'
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: freebsd-port-files
|
name: freebsd-port-files
|
||||||
path: |
|
path: |
|
||||||
@@ -133,25 +124,26 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
|
||||||
with:
|
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Parse semver string
|
- name: Parse semver string
|
||||||
id: semver_parser
|
id: semver_parser
|
||||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||||
|
with:
|
||||||
|
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||||
|
version_extractor_regex: '\/v(.*)$'
|
||||||
|
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -164,18 +156,18 @@ jobs:
|
|||||||
- name: check git status
|
- name: check git status
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
- name: Set up QEMU
|
- name: Set up QEMU
|
||||||
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0
|
uses: docker/setup-qemu-action@v2
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0
|
uses: docker/setup-buildx-action@v2
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request'
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@v1
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Log in to the GitHub container registry
|
- name: Log in to the GitHub container registry
|
||||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
username: ${{ github.actor }}
|
username: ${{ github.actor }}
|
||||||
@@ -199,7 +191,7 @@ jobs:
|
|||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --clean ${{ env.flags }}
|
args: release --clean ${{ env.flags }}
|
||||||
@@ -290,28 +282,28 @@ jobs:
|
|||||||
} >> "$GITHUB_OUTPUT"
|
} >> "$GITHUB_OUTPUT"
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release
|
id: upload_release
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload linux packages
|
- name: upload linux packages
|
||||||
id: upload_linux_packages
|
id: upload_linux_packages
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-packages
|
name: linux-packages
|
||||||
path: dist/netbird_linux**
|
path: dist/netbird_linux**
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload windows packages
|
- name: upload windows packages
|
||||||
id: upload_windows_packages
|
id: upload_windows_packages
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-packages
|
name: windows-packages
|
||||||
path: dist/netbird_windows**
|
path: dist/netbird_windows**
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload macos packages
|
- name: upload macos packages
|
||||||
id: upload_macos_packages
|
id: upload_macos_packages
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: macos-packages
|
name: macos-packages
|
||||||
path: dist/netbird_darwin**
|
path: dist/netbird_darwin**
|
||||||
@@ -322,26 +314,27 @@ jobs:
|
|||||||
outputs:
|
outputs:
|
||||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
|
||||||
with:
|
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Parse semver string
|
- name: Parse semver string
|
||||||
id: semver_parser
|
id: semver_parser
|
||||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||||
|
with:
|
||||||
|
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||||
|
version_extractor_regex: '\/v(.*)$'
|
||||||
|
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -382,7 +375,7 @@ jobs:
|
|||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
|
||||||
@@ -411,7 +404,7 @@ jobs:
|
|||||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release_ui
|
id: upload_release_ui
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -425,17 +418,16 @@ jobs:
|
|||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
with:
|
||||||
fetch-depth: 0 # It is required for GoReleaser to work properly
|
fetch-depth: 0 # It is required for GoReleaser to work properly
|
||||||
persist-credentials: false
|
|
||||||
- name: Set up Go
|
- name: Set up Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
cache: false
|
cache: false
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: |
|
path: |
|
||||||
~/go/pkg/mod
|
~/go/pkg/mod
|
||||||
@@ -449,7 +441,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
id: goreleaser
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
|
||||||
@@ -457,7 +449,7 @@ jobs:
|
|||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release_ui_darwin
|
id: upload_release_ui_darwin
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui-darwin
|
name: release-ui-darwin
|
||||||
path: dist/
|
path: dist/
|
||||||
@@ -482,26 +474,27 @@ jobs:
|
|||||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
||||||
downloadPath: '${{ github.workspace }}\temp'
|
downloadPath: '${{ github.workspace }}\temp'
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Parse semver string
|
- name: Parse semver string
|
||||||
id: semver_parser
|
id: semver_parser
|
||||||
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
uses: booxmedialtd/ws-action-parse-semver@v1
|
||||||
|
with:
|
||||||
|
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
||||||
|
version_extractor_regex: '\/v(.*)$'
|
||||||
|
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Add 7-Zip to PATH
|
- name: Add 7-Zip to PATH
|
||||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||||
|
|
||||||
- name: Download release artifacts
|
- name: Download release artifacts
|
||||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: release
|
path: release
|
||||||
|
|
||||||
- name: Download UI release artifacts
|
- name: Download UI release artifacts
|
||||||
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1
|
uses: actions/download-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
path: release-ui
|
path: release-ui
|
||||||
@@ -521,27 +514,29 @@ jobs:
|
|||||||
Get-ChildItem $workdir
|
Get-ChildItem $workdir
|
||||||
|
|
||||||
- name: Download wintun
|
- name: Download wintun
|
||||||
|
uses: carlosperate/download-file-action@v2
|
||||||
id: download-wintun
|
id: download-wintun
|
||||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
|
||||||
with:
|
with:
|
||||||
url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
||||||
destination: ${{ env.downloadPath }}\wintun.zip
|
file-name: wintun.zip
|
||||||
sha256: 07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51
|
location: ${{ env.downloadPath }}
|
||||||
|
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
||||||
|
|
||||||
- name: Decompress wintun files
|
- name: Decompress wintun files
|
||||||
run: tar -xvf "${{ env.downloadPath }}\wintun.zip" -C ${{ env.downloadPath }}
|
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
||||||
|
|
||||||
- name: Move wintun.dll into dist
|
- name: Move wintun.dll into dist
|
||||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||||
|
|
||||||
- name: Download Mesa3D (amd64 only)
|
- name: Download Mesa3D (amd64 only)
|
||||||
|
uses: carlosperate/download-file-action@v2
|
||||||
id: download-mesa3d
|
id: download-mesa3d
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
|
||||||
with:
|
with:
|
||||||
url: https://pkgs.netbird.io/mesa3d/MesaForWindows-x64-20.1.8.7z
|
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
||||||
destination: ${{ env.downloadPath }}\mesa3d.7z
|
file-name: mesa3d.7z
|
||||||
sha256: 71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9
|
location: ${{ env.downloadPath }}
|
||||||
|
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
||||||
|
|
||||||
- name: Extract Mesa3D driver (amd64 only)
|
- name: Extract Mesa3D driver (amd64 only)
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
@@ -552,38 +547,35 @@ jobs:
|
|||||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
||||||
|
|
||||||
- name: Download EnVar plugin for NSIS
|
- name: Download EnVar plugin for NSIS
|
||||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
uses: carlosperate/download-file-action@v2
|
||||||
with:
|
with:
|
||||||
url: https://pkgs.netbird.io/nsis/EnVar_plugin.zip
|
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
||||||
destination: ${{ github.workspace }}\envar_plugin.zip
|
file-name: envar_plugin.zip
|
||||||
sha256: e9aa92de351345ed82795251d838f1ae9041ba35af9d381a5780c7843b01f56a
|
location: ${{ github.workspace }}
|
||||||
|
|
||||||
- name: Extract EnVar plugin
|
- name: Extract EnVar plugin
|
||||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
||||||
|
|
||||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
||||||
|
uses: carlosperate/download-file-action@v2
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
uses: netbirdio/shared-actions/actions/win-download-and-verify@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
|
|
||||||
with:
|
with:
|
||||||
url: https://pkgs.netbird.io/nsis/ShellExecAsUser_amd64-Unicode.7z
|
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
||||||
destination: ${{ github.workspace }}\ShellExecAsUser_amd64-Unicode.7z
|
file-name: ShellExecAsUser_amd64-Unicode.7z
|
||||||
sha256: 0a55ea25c7330a92cec028eda8afcaf1b1a7092e0dfb77c21c8f654564b4ff9d
|
location: ${{ github.workspace }}
|
||||||
|
|
||||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
- name: Extract ShellExecAsUser plugin (amd64 only)
|
||||||
if: matrix.arch == 'amd64'
|
if: matrix.arch == 'amd64'
|
||||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
||||||
|
|
||||||
- name: Build NSIS installer
|
- name: Build NSIS installer
|
||||||
shell: pwsh
|
uses: joncloud/makensis-action@v3.3
|
||||||
|
with:
|
||||||
|
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
||||||
|
script-file: client/installer.nsis
|
||||||
|
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
||||||
env:
|
env:
|
||||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
||||||
run: |
|
|
||||||
$nsisPluginDir = "C:\Program Files (x86)\NSIS\Plugins\x86-unicode"
|
|
||||||
$srcPlugins = "${{ github.workspace }}\NSIS_Plugins\Plugins"
|
|
||||||
Get-ChildItem -Path $srcPlugins -Recurse -Filter *.dll |
|
|
||||||
Copy-Item -Destination $nsisPluginDir -Force
|
|
||||||
& "C:\Program Files (x86)\NSIS\makensis.exe" /V4 "/DARCH=${{ matrix.arch }}" client\installer.nsis
|
|
||||||
if ($LASTEXITCODE -ne 0) { throw "makensis failed with exit code $LASTEXITCODE" }
|
|
||||||
|
|
||||||
- name: Rename NSIS installer
|
- name: Rename NSIS installer
|
||||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
||||||
@@ -600,7 +592,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Upload installer artifacts
|
- name: Upload installer artifacts
|
||||||
if: always()
|
if: always()
|
||||||
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a #v7.0.1
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-installer-test-${{ matrix.arch }}
|
name: windows-installer-test-${{ matrix.arch }}
|
||||||
path: |
|
path: |
|
||||||
@@ -619,7 +611,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
steps:
|
steps:
|
||||||
- name: Create or update PR comment
|
- name: Create or update PR comment
|
||||||
uses: actions/github-script@3a2844b7e9c422d3c10d287c895573f7108da1b3 # v9.0.0
|
uses: actions/github-script@v7
|
||||||
env:
|
env:
|
||||||
RELEASE_RESULT: ${{ needs.release.result }}
|
RELEASE_RESULT: ${{ needs.release.result }}
|
||||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
||||||
@@ -711,7 +703,7 @@ jobs:
|
|||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger binaries sign pipelines
|
- name: Trigger binaries sign pipelines
|
||||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
uses: benc-uk/workflow-dispatch@v1
|
||||||
with:
|
with:
|
||||||
workflow: Sign bin and installer
|
workflow: Sign bin and installer
|
||||||
repo: netbirdio/sign-pipelines
|
repo: netbirdio/sign-pipelines
|
||||||
|
|||||||
4
.github/workflows/sync-main.yml
vendored
4
.github/workflows/sync-main.yml
vendored
@@ -14,9 +14,9 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger main branch sync
|
- name: Trigger main branch sync
|
||||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
uses: benc-uk/workflow-dispatch@v1
|
||||||
with:
|
with:
|
||||||
workflow: sync-main.yml
|
workflow: sync-main.yml
|
||||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||||
inputs: '{ "sha": "${{ github.sha }}" }'
|
inputs: '{ "sha": "${{ github.sha }}" }'
|
||||||
10
.github/workflows/sync-tag.yml
vendored
10
.github/workflows/sync-tag.yml
vendored
@@ -3,7 +3,7 @@ name: sync tag
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- "v*"
|
- 'v*'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
@@ -16,7 +16,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger release tag sync
|
- name: Trigger release tag sync
|
||||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
uses: benc-uk/workflow-dispatch@v1
|
||||||
with:
|
with:
|
||||||
workflow: sync-tag.yml
|
workflow: sync-tag.yml
|
||||||
ref: main
|
ref: main
|
||||||
@@ -29,7 +29,7 @@ jobs:
|
|||||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger android-client submodule bump
|
- name: Trigger android-client submodule bump
|
||||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||||
with:
|
with:
|
||||||
workflow: bump-netbird.yml
|
workflow: bump-netbird.yml
|
||||||
ref: main
|
ref: main
|
||||||
@@ -42,10 +42,10 @@ jobs:
|
|||||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger ios-client submodule bump
|
- name: Trigger ios-client submodule bump
|
||||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
||||||
with:
|
with:
|
||||||
workflow: bump-netbird.yml
|
workflow: bump-netbird.yml
|
||||||
ref: main
|
ref: main
|
||||||
repo: netbirdio/ios-client
|
repo: netbirdio/ios-client
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||||
26
.github/workflows/test-infrastructure-files.yml
vendored
26
.github/workflows/test-infrastructure-files.yml
vendored
@@ -6,10 +6,10 @@ on:
|
|||||||
- main
|
- main
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- "infrastructure_files/**"
|
- 'infrastructure_files/**'
|
||||||
- ".github/workflows/test-infrastructure-files.yml"
|
- '.github/workflows/test-infrastructure-files.yml'
|
||||||
- "management/cmd/**"
|
- 'management/cmd/**'
|
||||||
- "signal/cmd/**"
|
- 'signal/cmd/**'
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
@@ -20,7 +20,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
store: ["sqlite", "postgres", "mysql"]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||||
@@ -68,17 +68,15 @@ jobs:
|
|||||||
run: sudo apt-get install -y curl
|
run: sudo apt-get install -y curl
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
||||||
@@ -141,8 +139,8 @@ jobs:
|
|||||||
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
|
||||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: "${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$"
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: "${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$"
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||||
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||||
@@ -256,9 +254,7 @@ jobs:
|
|||||||
run: sudo apt-get install -y jq
|
run: sudo apt-get install -y jq
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
|
|
||||||
- name: run script with Zitadel PostgreSQL
|
- name: run script with Zitadel PostgreSQL
|
||||||
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
|
||||||
|
|||||||
8
.github/workflows/update-docs.yml
vendored
8
.github/workflows/update-docs.yml
vendored
@@ -3,9 +3,9 @@ name: update docs
|
|||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
tags:
|
tags:
|
||||||
- "v*"
|
- 'v*'
|
||||||
paths:
|
paths:
|
||||||
- "shared/management/http/api/openapi.yml"
|
- 'shared/management/http/api/openapi.yml'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
trigger_docs_api_update:
|
trigger_docs_api_update:
|
||||||
@@ -13,10 +13,10 @@ jobs:
|
|||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger API pages generation
|
- name: Trigger API pages generation
|
||||||
uses: benc-uk/workflow-dispatch@31e2b3319479a63f0ab15bf800eff9e913504e26 # v1.3.2
|
uses: benc-uk/workflow-dispatch@v1
|
||||||
with:
|
with:
|
||||||
workflow: generate api pages
|
workflow: generate api pages
|
||||||
repo: netbirdio/docs
|
repo: netbirdio/docs
|
||||||
ref: "refs/heads/main"
|
ref: "refs/heads/main"
|
||||||
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref }}" }'
|
inputs: '{ "tag": "${{ github.ref }}" }'
|
||||||
19
.github/workflows/wasm-build-validation.yml
vendored
19
.github/workflows/wasm-build-validation.yml
vendored
@@ -19,17 +19,15 @@ jobs:
|
|||||||
GOARCH: wasm
|
GOARCH: wasm
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: Install golangci-lint
|
- name: Install golangci-lint
|
||||||
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
|
uses: golangci/golangci-lint-action@4afd733a84b1f43292c63897423277bb7f4313a9 # v8.0.0
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: latest
|
||||||
install-mode: binary
|
install-mode: binary
|
||||||
@@ -44,11 +42,9 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
uses: actions/checkout@v4
|
||||||
with:
|
|
||||||
persist-credentials: false
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version-file: "go.mod"
|
go-version-file: "go.mod"
|
||||||
- name: Build Wasm client
|
- name: Build Wasm client
|
||||||
@@ -65,7 +61,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 62914560 ]; then
|
if [ ${SIZE} -gt 58720256 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/upload-server/types"
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -101,7 +100,6 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
CliVersion: version.NetbirdVersion(),
|
|
||||||
}
|
}
|
||||||
if uploadBundleFlag {
|
if uploadBundleFlag {
|
||||||
request.UploadURL = uploadBundleURLFlag
|
request.UploadURL = uploadBundleURLFlag
|
||||||
@@ -300,7 +298,6 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
SystemInfo: systemInfoFlag,
|
SystemInfo: systemInfoFlag,
|
||||||
LogFileCount: logFileCount,
|
LogFileCount: logFileCount,
|
||||||
CliVersion: version.NetbirdVersion(),
|
|
||||||
}
|
}
|
||||||
if uploadBundleFlag {
|
if uploadBundleFlag {
|
||||||
request.UploadURL = uploadBundleURLFlag
|
request.UploadURL = uploadBundleURLFlag
|
||||||
@@ -435,7 +432,6 @@ func generateDebugBundle(config *profilemanager.Config, recorder *peer.Status, c
|
|||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: logFilePath,
|
LogPath: logFilePath,
|
||||||
CPUProfile: nil,
|
CPUProfile: nil,
|
||||||
DaemonVersion: version.NetbirdVersion(), // acting as daemon
|
|
||||||
},
|
},
|
||||||
debug.BundleConfig{
|
debug.BundleConfig{
|
||||||
IncludeSystemInfo: true,
|
IncludeSystemInfo: true,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -22,15 +23,21 @@ var serviceCmd = &cobra.Command{
|
|||||||
Short: "Manage the NetBird daemon service",
|
Short: "Manage the NetBird daemon service",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const defaultJSONSocket = "unix:///var/run/netbird-http.sock"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
serviceName string
|
serviceName string
|
||||||
serviceEnvVars []string
|
serviceEnvVars []string
|
||||||
|
jsonSocket string
|
||||||
|
jsonSocketDisabled bool
|
||||||
)
|
)
|
||||||
|
|
||||||
type program struct {
|
type program struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
serv *grpc.Server
|
serv *grpc.Server
|
||||||
|
jsonServ *http.Server
|
||||||
|
jsonServMu sync.Mutex
|
||||||
serverInstance *server.Server
|
serverInstance *server.Server
|
||||||
serverInstanceMu sync.Mutex
|
serverInstanceMu sync.Mutex
|
||||||
}
|
}
|
||||||
@@ -46,6 +53,8 @@ func init() {
|
|||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
|
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||||
|
serviceCmd.PersistentFlags().StringVar(&jsonSocket, "json-socket", defaultJSONSocket, "HTTP/JSON API socket address served by grpc-gateway [unix|tcp]://[path|host:port]. To persist, use: netbird service install --json-socket")
|
||||||
|
serviceCmd.PersistentFlags().BoolVar(&jsonSocketDisabled, "disable-json-socket", false, "Disables the HTTP/JSON API socket. To persist, use: netbird service install --disable-json-socket")
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
serviceEnvDesc := `Sets extra environment variables for the service. ` +
|
||||||
|
|||||||
@@ -5,9 +5,6 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -32,31 +29,35 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||||
p.serv = grpc.NewServer()
|
p.serv = grpc.NewServer()
|
||||||
|
|
||||||
split := strings.Split(daemonAddr, "://")
|
daemonListener, err := listenOnAddress(daemonAddr)
|
||||||
switch split[0] {
|
|
||||||
case "unix":
|
|
||||||
// cleanup failed close
|
|
||||||
stat, err := os.Stat(split[1])
|
|
||||||
if err == nil && !stat.IsDir() {
|
|
||||||
if err := os.Remove(split[1]); err != nil {
|
|
||||||
log.Debugf("remove socket file: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "tcp":
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported daemon address protocol: %v", split[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
listen, err := net.Listen(split[0], split[1])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("listen daemon interface: %w", err)
|
return fmt.Errorf("listen daemon interface: %w", err)
|
||||||
}
|
}
|
||||||
go func() {
|
|
||||||
defer listen.Close()
|
|
||||||
|
|
||||||
if split[0] == "unix" {
|
var jsonListener *socketListener
|
||||||
if err := os.Chmod(split[1], 0666); err != nil {
|
if !jsonSocketDisabled {
|
||||||
log.Errorf("failed setting daemon permissions: %v", split[1])
|
jsonListener, err = listenOnAddress(jsonSocket)
|
||||||
|
if err != nil {
|
||||||
|
_ = daemonListener.Close()
|
||||||
|
return fmt.Errorf("listen daemon JSON interface: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
removeStaleUnixSocketForAddress(jsonSocket)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer daemonListener.Close()
|
||||||
|
if jsonListener != nil {
|
||||||
|
defer jsonListener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := daemonListener.chmodUnixSocket("daemon"); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if jsonListener != nil {
|
||||||
|
if err := jsonListener.chmodUnixSocket("daemon JSON"); err != nil {
|
||||||
|
log.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,8 +72,16 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
p.serverInstance = serverInstance
|
p.serverInstance = serverInstance
|
||||||
p.serverInstanceMu.Unlock()
|
p.serverInstanceMu.Unlock()
|
||||||
|
|
||||||
log.Printf("started daemon server: %v", split[1])
|
if jsonListener != nil {
|
||||||
if err := p.serv.Serve(listen); err != nil {
|
if err := p.startJSONGateway(jsonListener, daemonAddr); err != nil {
|
||||||
|
log.Fatalf("failed to start daemon JSON server: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Debug("daemon JSON socket disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("started daemon server: %v", daemonListener.address)
|
||||||
|
if err := p.serv.Serve(daemonListener.Listener); err != nil {
|
||||||
log.Errorf("failed to serve daemon requests: %v", err)
|
log.Errorf("failed to serve daemon requests: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -92,6 +101,20 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
|
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
|
p.jsonServMu.Lock()
|
||||||
|
jsonServ := p.jsonServ
|
||||||
|
p.jsonServMu.Unlock()
|
||||||
|
if jsonServ != nil {
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
if err := jsonServ.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Errorf("failed to stop daemon JSON server gracefully: %v", err)
|
||||||
|
if err := jsonServ.Close(); err != nil {
|
||||||
|
log.Errorf("failed to close daemon JSON server: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
shutdownCancel()
|
||||||
|
}
|
||||||
|
|
||||||
if p.serv != nil {
|
if p.serv != nil {
|
||||||
p.serv.Stop()
|
p.serv.Stop()
|
||||||
}
|
}
|
||||||
@@ -102,7 +125,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Common setup for service control commands
|
// Common setup for service control commands
|
||||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc, consoleLog bool) (service.Service, error) {
|
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||||
// rootCmd env vars are already applied by PersistentPreRunE.
|
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||||
SetFlagsFromEnvVars(serviceCmd)
|
SetFlagsFromEnvVars(serviceCmd)
|
||||||
|
|
||||||
@@ -112,14 +135,8 @@ func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if consoleLog {
|
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
||||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
return nil, fmt.Errorf("init log: %w", err)
|
||||||
return nil, fmt.Errorf("init log: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := util.InitLog(logLevel, logFiles...); err != nil {
|
|
||||||
return nil, fmt.Errorf("init log: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg, err := newSVCConfig()
|
cfg, err := newSVCConfig()
|
||||||
@@ -144,7 +161,7 @@ var runCmd = &cobra.Command{
|
|||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
|
SetupDebugHandler(ctx, nil, nil, nil, util.FindFirstLogPath(logFiles))
|
||||||
|
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -158,7 +175,7 @@ var startCmd = &cobra.Command{
|
|||||||
Short: "starts NetBird service",
|
Short: "starts NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -176,7 +193,7 @@ var stopCmd = &cobra.Command{
|
|||||||
Short: "stops NetBird service",
|
Short: "stops NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -194,7 +211,7 @@ var restartCmd = &cobra.Command{
|
|||||||
Short: "restarts NetBird service",
|
Short: "restarts NetBird service",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, false)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -212,7 +229,7 @@ var svcStatusCmd = &cobra.Command{
|
|||||||
Short: "shows NetBird service status",
|
Short: "shows NetBird service status",
|
||||||
RunE: func(cmd *cobra.Command, args []string) error {
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
s, err := setupServiceControlCommand(cmd, ctx, cancel, true)
|
s, err := setupServiceControlCommand(cmd, ctx, cancel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,6 +67,11 @@ func buildServiceArguments() []string {
|
|||||||
args = append(args, "--disable-networks")
|
args = append(args, "--disable-networks")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
args = append(args, "--json-socket", jsonSocket)
|
||||||
|
if jsonSocketDisabled {
|
||||||
|
args = append(args, "--disable-json-socket")
|
||||||
|
}
|
||||||
|
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
52
client/cmd/service_json_gateway.go
Normal file
52
client/cmd/service_json_gateway.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func grpcGatewayEndpoint(addr string) string {
|
||||||
|
return strings.TrimPrefix(addr, "tcp://")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *program) startJSONGateway(jsonListener *socketListener, daemonEndpoint string) error {
|
||||||
|
mux := runtime.NewServeMux()
|
||||||
|
opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
|
||||||
|
if err := proto.RegisterDaemonServiceHandlerFromEndpoint(p.ctx, mux, grpcGatewayEndpoint(daemonEndpoint), opts); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonServer := &http.Server{
|
||||||
|
Handler: mux,
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
BaseContext: func(net.Listener) context.Context {
|
||||||
|
return p.ctx
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
p.jsonServMu.Lock()
|
||||||
|
p.jsonServ = jsonServer
|
||||||
|
p.jsonServMu.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
log.Printf("started daemon JSON server: %v", jsonListener.address)
|
||||||
|
if err := jsonServer.Serve(jsonListener.Listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Errorf("failed to serve daemon JSON requests: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -23,6 +23,7 @@ const serviceParamsFile = "service.json"
|
|||||||
type serviceParams struct {
|
type serviceParams struct {
|
||||||
LogLevel string `json:"log_level"`
|
LogLevel string `json:"log_level"`
|
||||||
DaemonAddr string `json:"daemon_addr"`
|
DaemonAddr string `json:"daemon_addr"`
|
||||||
|
JSONSocket string `json:"json_socket"`
|
||||||
ManagementURL string `json:"management_url,omitempty"`
|
ManagementURL string `json:"management_url,omitempty"`
|
||||||
ConfigPath string `json:"config_path,omitempty"`
|
ConfigPath string `json:"config_path,omitempty"`
|
||||||
LogFiles []string `json:"log_files,omitempty"`
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
@@ -30,6 +31,7 @@ type serviceParams struct {
|
|||||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
EnableCapture bool `json:"enable_capture,omitempty"`
|
||||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||||
|
DisableJSONSocket bool `json:"disable_json_socket,omitempty"`
|
||||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,6 +77,7 @@ func currentServiceParams() *serviceParams {
|
|||||||
params := &serviceParams{
|
params := &serviceParams{
|
||||||
LogLevel: logLevel,
|
LogLevel: logLevel,
|
||||||
DaemonAddr: daemonAddr,
|
DaemonAddr: daemonAddr,
|
||||||
|
JSONSocket: jsonSocket,
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
LogFiles: logFiles,
|
LogFiles: logFiles,
|
||||||
@@ -82,6 +85,7 @@ func currentServiceParams() *serviceParams {
|
|||||||
DisableUpdateSettings: updateSettingsDisabled,
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
EnableCapture: captureEnabled,
|
EnableCapture: captureEnabled,
|
||||||
DisableNetworks: networksDisabled,
|
DisableNetworks: networksDisabled,
|
||||||
|
DisableJSONSocket: jsonSocketDisabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(serviceEnvVars) > 0 {
|
if len(serviceEnvVars) > 0 {
|
||||||
@@ -113,9 +117,8 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
// For fields with non-empty defaults, keep the != "" guard so that an older
|
||||||
// != "" guard so that an older service.json missing the field doesn't
|
// service.json missing the field doesn't clobber the default with an empty string.
|
||||||
// clobber the default with an empty string.
|
|
||||||
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||||
logLevel = params.LogLevel
|
logLevel = params.LogLevel
|
||||||
}
|
}
|
||||||
@@ -124,6 +127,20 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
|||||||
daemonAddr = params.DaemonAddr
|
daemonAddr = params.DaemonAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
jsonSocketChanged := serviceCmd.PersistentFlags().Changed("json-socket")
|
||||||
|
if !jsonSocketChanged && params.JSONSocket != "" {
|
||||||
|
jsonSocket = params.JSONSocket
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-json-socket") {
|
||||||
|
jsonSocketDisabled = params.DisableJSONSocket
|
||||||
|
// Passing --json-socket should re-enable the JSON gateway unless
|
||||||
|
// --disable-json-socket was explicitly provided too.
|
||||||
|
if jsonSocketChanged {
|
||||||
|
jsonSocketDisabled = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// For optional fields where empty means "use default", always apply so
|
// For optional fields where empty means "use default", always apply so
|
||||||
// that an explicit clear (--management-url "") persists across reinstalls.
|
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||||
if !rootCmd.PersistentFlags().Changed("management-url") {
|
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||||
|
|||||||
@@ -530,6 +530,7 @@ func fieldToGlobalVar(field string) string {
|
|||||||
m := map[string]string{
|
m := map[string]string{
|
||||||
"LogLevel": "logLevel",
|
"LogLevel": "logLevel",
|
||||||
"DaemonAddr": "daemonAddr",
|
"DaemonAddr": "daemonAddr",
|
||||||
|
"JSONSocket": "jsonSocket",
|
||||||
"ManagementURL": "managementURL",
|
"ManagementURL": "managementURL",
|
||||||
"ConfigPath": "configPath",
|
"ConfigPath": "configPath",
|
||||||
"LogFiles": "logFiles",
|
"LogFiles": "logFiles",
|
||||||
@@ -537,6 +538,7 @@ func fieldToGlobalVar(field string) string {
|
|||||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
"EnableCapture": "captureEnabled",
|
"EnableCapture": "captureEnabled",
|
||||||
"DisableNetworks": "networksDisabled",
|
"DisableNetworks": "networksDisabled",
|
||||||
|
"DisableJSONSocket": "jsonSocketDisabled",
|
||||||
"ServiceEnvVars": "serviceEnvVars",
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
}
|
}
|
||||||
if v, ok := m[field]; ok {
|
if v, ok := m[field]; ok {
|
||||||
|
|||||||
83
client/cmd/service_socket.go
Normal file
83
client/cmd/service_socket.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type socketListener struct {
|
||||||
|
net.Listener
|
||||||
|
network string
|
||||||
|
address string
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenOnAddress(addr string) (*socketListener, error) {
|
||||||
|
network, address, err := parseListenAddress(addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if network == "unix" {
|
||||||
|
removeStaleUnixSocket(address)
|
||||||
|
}
|
||||||
|
|
||||||
|
listener, err := net.Listen(network, address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &socketListener{Listener: listener, network: network, address: address}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseListenAddress(addr string) (string, string, error) {
|
||||||
|
network, address, ok := strings.Cut(addr, "://")
|
||||||
|
if !ok || network == "" || address == "" {
|
||||||
|
return "", "", fmt.Errorf("address must be in [unix|tcp]://[path|host:port] format: %q", addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch network {
|
||||||
|
case "unix", "tcp":
|
||||||
|
return network, address, nil
|
||||||
|
default:
|
||||||
|
return "", "", fmt.Errorf("unsupported daemon address protocol: %v", network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeStaleUnixSocket(path string) {
|
||||||
|
stat, err := os.Stat(path)
|
||||||
|
if err == nil && !stat.IsDir() {
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
log.Debugf("remove socket file: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
log.Debugf("stat socket file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeStaleUnixSocketForAddress(addr string) {
|
||||||
|
network, address, err := parseListenAddress(addr)
|
||||||
|
if err != nil || network != "unix" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
removeStaleUnixSocket(address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *socketListener) chmodUnixSocket(description string) error {
|
||||||
|
if l == nil || l.network != "unix" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Chmod(l.address, 0666); err != nil {
|
||||||
|
return fmt.Errorf("failed setting %s permissions for %s: %w", description, l.address, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -12,13 +12,7 @@ var (
|
|||||||
Short: "Print the NetBird's client application version",
|
Short: "Print the NetBird's client application version",
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
out := version.NetbirdVersion()
|
cmd.Println(version.NetbirdVersion())
|
||||||
if version.IsDevelopmentVersion(out) {
|
|
||||||
if commit := version.NetbirdCommit(); commit != "" {
|
|
||||||
out += "-" + commit
|
|
||||||
}
|
|
||||||
}
|
|
||||||
cmd.Println(out)
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
//go:build android || (!linux && !windows)
|
|
||||||
|
|
||||||
package firewall
|
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
|
|
||||||
// interfaceAllower returns no allower: these platforms have no host firewall to
|
|
||||||
// open for the interface.
|
|
||||||
func interfaceAllower(IFaceMapper, uint16) uspfilter.InterfaceAllower {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package firewall
|
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/firewall/uspfilter"
|
|
||||||
|
|
||||||
// interfaceAllower returns the Windows netsh-based interface allower.
|
|
||||||
func interfaceAllower(iface IFaceMapper, _ uint16) uspfilter.InterfaceAllower {
|
|
||||||
return uspfilter.NewWindowsInterfaceAllower(iface)
|
|
||||||
}
|
|
||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
@@ -19,11 +21,13 @@ func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.
|
|||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
return uspfilter.Create(uspfilter.Config{
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||||
IFace: iface,
|
if err != nil {
|
||||||
DisableServerRoutes: disableServerRoutes,
|
return nil, err
|
||||||
FlowLogger: flowLogger,
|
}
|
||||||
MTU: mtu,
|
err = fm.AllowNetbird()
|
||||||
InterfaceAllower: interfaceAllower(iface, mtu),
|
if err != nil {
|
||||||
})
|
log.Warnf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -30,107 +29,47 @@ const (
|
|||||||
NFTABLES
|
NFTABLES
|
||||||
)
|
)
|
||||||
|
|
||||||
// SkipNftablesEnv is the environment variable to skip nftables check
|
// SKIP_NFTABLES_ENV is the environment variable to skip nftables check
|
||||||
const SkipNftablesEnv = "NB_SKIP_NFTABLES_CHECK"
|
const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||||
|
|
||||||
// errNoFirewallManager indicates no kernel firewall backend is present,
|
|
||||||
// as opposed to a backend that exists but failed to create or initialize.
|
|
||||||
var errNoFirewallManager = errors.New("no firewall manager found")
|
|
||||||
|
|
||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
// Userspace firewall without a native counterpart: routing is handled
|
// We run in userspace mode and force userspace firewall was requested. We don't attempt native firewall.
|
||||||
// entirely in userspace. The interface is opened in the kernel's foreign
|
if iface.IsUserspaceBind() && forceUserspaceFirewall() {
|
||||||
// filter chains via a table-less allower, except in netstack mode where no
|
log.Info("forcing userspace firewall")
|
||||||
// kernel interface exists.
|
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
if netstack.IsEnabled() || (iface.IsUserspaceBind() && forceUserspaceFirewall()) {
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
log.Info("netstack mode, using userspace firewall")
|
|
||||||
} else {
|
|
||||||
log.Info("forcing userspace firewall")
|
|
||||||
}
|
|
||||||
cfg := uspfilter.Config{
|
|
||||||
IFace: iface,
|
|
||||||
DisableServerRoutes: disableServerRoutes,
|
|
||||||
FlowLogger: flowLogger,
|
|
||||||
MTU: mtu,
|
|
||||||
InterfaceAllower: interfaceAllower(iface, mtu),
|
|
||||||
}
|
|
||||||
|
|
||||||
return uspfilter.Create(cfg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
// Use native firewall for either kernel or userspace, the interface appears identical to netfilter
|
||||||
fm, err := createNativeFirewall(iface, stateManager, mtu)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||||
switch {
|
|
||||||
case err == nil && !iface.IsUserspaceBind():
|
// Kernel cannot fall back to anything else, need to return error
|
||||||
// Nothing to do, fall through
|
if !iface.IsUserspaceBind() {
|
||||||
case err == nil && iface.IsUserspaceBind():
|
return fm, err
|
||||||
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
}
|
||||||
// needs a device filter for DNS interception hooks. Install a minimal
|
|
||||||
// hooks-only filter that passes all traffic through to the kernel firewall.
|
// Fall back to the userspace packet filter if native is unavailable
|
||||||
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
case err != nil && !iface.IsUserspaceBind():
|
}
|
||||||
// Kernel cannot fall back to anything else, need to return error
|
|
||||||
return nil, err
|
// Native firewall handles packet filtering, but the userspace WireGuard bind
|
||||||
case err != nil && iface.IsUserspaceBind():
|
// needs a device filter for DNS interception hooks. Install a minimal
|
||||||
// Fall back to the userspace packet filter if native is unavailable
|
// hooks-only filter that passes all traffic through to the kernel firewall.
|
||||||
logNativeFirewallUnavailable(err)
|
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
|
||||||
return uspfilter.Create(uspfilter.Config{
|
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
|
||||||
IFace: iface,
|
|
||||||
DisableServerRoutes: disableServerRoutes,
|
|
||||||
FlowLogger: flowLogger,
|
|
||||||
MTU: mtu,
|
|
||||||
InterfaceAllower: interfaceAllower(iface, mtu),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// interfaceAllower selects how the userspace firewall opens the interface in
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
// foreign kernel chains: nftables when available (which also opens foreign nft
|
|
||||||
// tables), else iptables (the legacy fallback, filter INPUT only), else nil.
|
|
||||||
// firewalld trust is applied separately by the manager. Netstack has no kernel
|
|
||||||
// interface to open.
|
|
||||||
func interfaceAllower(iface IFaceMapper, mtu uint16) uspfilter.InterfaceAllower {
|
|
||||||
if netstack.IsEnabled() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nftAllower, err := nbnftables.NewInterfaceAllower(iface, mtu)
|
|
||||||
if err == nil {
|
|
||||||
return nftAllower
|
|
||||||
}
|
|
||||||
log.Infof("no nftables interface allower: %v", err)
|
|
||||||
|
|
||||||
iptAllower, err := nbiptables.NewInterfaceAllower(iface)
|
|
||||||
if err == nil {
|
|
||||||
return iptAllower
|
|
||||||
}
|
|
||||||
log.Infof("no iptables interface allower: %v", err)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// logNativeFirewallUnavailable logs the fallback to userspace at info level
|
|
||||||
// when no kernel firewall backend exists, and at warn level otherwise.
|
|
||||||
func logNativeFirewallUnavailable(err error) {
|
|
||||||
if errors.Is(err, errNoFirewallManager) {
|
|
||||||
log.Infof("no native firewall backend available: %v. Proceeding with userspace", err)
|
|
||||||
} else {
|
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, mtu uint16) (firewall.Manager, error) {
|
|
||||||
fm, err := createFW(iface, mtu)
|
fm, err := createFW(iface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %w", err)
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = fm.Init(stateManager); err != nil {
|
if err = fm.Init(stateManager); err != nil {
|
||||||
@@ -149,10 +88,29 @@ func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
|||||||
log.Info("creating an nftables firewall manager")
|
log.Info("creating an nftables firewall manager")
|
||||||
return nbnftables.Create(iface, mtu)
|
return nbnftables.Create(iface, mtu)
|
||||||
default:
|
default:
|
||||||
return nil, errNoFirewallManager
|
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||||
|
return nil, errors.New("no firewall manager found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||||
|
var errUsp error
|
||||||
|
if fm != nil {
|
||||||
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||||
|
} else {
|
||||||
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
if errUsp != nil {
|
||||||
|
return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := fm.AllowNetbird(); err != nil {
|
||||||
|
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
||||||
|
}
|
||||||
|
return fm, nil
|
||||||
|
}
|
||||||
|
|
||||||
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
|
||||||
func check() FWType {
|
func check() FWType {
|
||||||
useIPTABLES := false
|
useIPTABLES := false
|
||||||
@@ -174,38 +132,35 @@ func check() FWType {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Honor the skip env before probing nftables at all.
|
nf := nftables.Conn{}
|
||||||
if os.Getenv(SkipNftablesEnv) != "true" {
|
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
|
||||||
nf := nftables.Conn{}
|
if !useIPTABLES {
|
||||||
if chains, err := nf.ListChains(); err == nil {
|
return NFTABLES
|
||||||
if !useIPTABLES {
|
}
|
||||||
|
|
||||||
|
// search for chains where table is filter
|
||||||
|
// if we find one, we assume that nftables manager can be used with iptables
|
||||||
|
for _, chain := range chains {
|
||||||
|
if chain.Table.Name == "filter" {
|
||||||
return NFTABLES
|
return NFTABLES
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// search for chains where table is filter
|
// check tables for the following constraints:
|
||||||
// if we find one, we assume that nftables manager can be used with iptables
|
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
||||||
for _, chain := range chains {
|
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
||||||
if chain.Table.Name == "filter" {
|
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
||||||
return NFTABLES
|
// 4. if we find an error we log and continue with iptables check
|
||||||
}
|
nbTablesList, err := nf.ListTables()
|
||||||
}
|
switch {
|
||||||
|
case err == nil && len(iptablesChains) > 0:
|
||||||
// check tables for the following constraints:
|
return IPTABLES
|
||||||
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
|
case err == nil && len(nbTablesList) != 1:
|
||||||
// 2. there is no tables or more than one table, we assume that nftables manager can be used
|
return NFTABLES
|
||||||
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
|
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
||||||
// 4. if we find an error we log and continue with iptables check
|
return IPTABLES
|
||||||
nbTablesList, err := nf.ListTables()
|
case err != nil:
|
||||||
switch {
|
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
||||||
case err == nil && len(iptablesChains) > 0:
|
|
||||||
return IPTABLES
|
|
||||||
case err == nil && len(nbTablesList) != 1:
|
|
||||||
return NFTABLES
|
|
||||||
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
|
|
||||||
return IPTABLES
|
|
||||||
case err != nil:
|
|
||||||
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -221,21 +176,15 @@ func isIptablesClientAvailable(client *iptables.IPTables) bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// forceUserspaceFirewall reports whether the userspace firewall is forced.
|
|
||||||
// NB_FORCE_USERSPACE_ROUTER is an alias: forcing userspace routing implies the
|
|
||||||
// userspace firewall, since the two are no longer separable.
|
|
||||||
func forceUserspaceFirewall() bool {
|
func forceUserspaceFirewall() bool {
|
||||||
return envForceBool(EnvForceUserspaceFirewall) || envForceBool(uspfilter.EnvForceUserspaceRouter)
|
val := os.Getenv(EnvForceUserspaceFirewall)
|
||||||
}
|
|
||||||
|
|
||||||
func envForceBool(name string) bool {
|
|
||||||
val := os.Getenv(name)
|
|
||||||
if val == "" {
|
if val == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
force, err := strconv.ParseBool(val)
|
force, err := strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", name, err)
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceFirewall, err)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return force
|
return force
|
||||||
|
|||||||
554
client/firewall/iptables/acl_linux.go
Normal file
554
client/firewall/iptables/acl_linux.go
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
"github.com/coreos/go-iptables/iptables"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
ipset "github.com/lrh3321/ipset-go"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tableName = "filter"
|
||||||
|
|
||||||
|
// rules chains contains the effective ACL rules
|
||||||
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
|
|
||||||
|
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
|
||||||
|
// external DNAT from bypassing ACL rules.
|
||||||
|
mangleFwdKey = "MANGLE-FORWARD"
|
||||||
|
)
|
||||||
|
|
||||||
|
type aclEntries map[string][][]string
|
||||||
|
|
||||||
|
type entry struct {
|
||||||
|
spec []string
|
||||||
|
position int
|
||||||
|
}
|
||||||
|
|
||||||
|
type aclManager struct {
|
||||||
|
iptablesClient *iptables.IPTables
|
||||||
|
wgIface iFaceMapper
|
||||||
|
entries aclEntries
|
||||||
|
optionalEntries map[string][]entry
|
||||||
|
ipsetStore *ipsetStore
|
||||||
|
v6 bool
|
||||||
|
|
||||||
|
stateManager *statemanager.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
|
return &aclManager{
|
||||||
|
iptablesClient: iptablesClient,
|
||||||
|
wgIface: wgIface,
|
||||||
|
entries: make(map[string][][]string),
|
||||||
|
optionalEntries: make(map[string][]entry),
|
||||||
|
ipsetStore: newIpsetStore(),
|
||||||
|
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
||||||
|
m.stateManager = stateManager
|
||||||
|
|
||||||
|
m.seedInitialEntries()
|
||||||
|
m.seedInitialOptionalEntries()
|
||||||
|
|
||||||
|
if err := m.cleanChains(); err != nil {
|
||||||
|
return fmt.Errorf("clean chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.createDefaultChains(); err != nil {
|
||||||
|
return fmt.Errorf("create default chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
|
ip net.IP,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
|
chain := chainNameInputRules
|
||||||
|
|
||||||
|
ipsetName = transformIPsetName(ipsetName, sPort, dPort, action)
|
||||||
|
if m.v6 && ipsetName != "" {
|
||||||
|
ipsetName += "-v6"
|
||||||
|
}
|
||||||
|
proto := protoForFamily(protocol, m.v6)
|
||||||
|
specs := filterRuleSpecs(ip, proto, sPort, dPort, action, ipsetName)
|
||||||
|
|
||||||
|
mangleSpecs := slices.Clone(specs)
|
||||||
|
mangleSpecs = append(mangleSpecs,
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
|
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
)
|
||||||
|
|
||||||
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
|
if ipsetName != "" {
|
||||||
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
|
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
// if ruleset already exists it means we already have the firewall rule
|
||||||
|
// so we need to update IPs in the ruleset and return new fw.Rule object for ACL manager.
|
||||||
|
ipList.addIP(ip.String())
|
||||||
|
return []firewall.Rule{&Rule{
|
||||||
|
ruleID: uuid.New().String(),
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
chain: chain,
|
||||||
|
specs: specs,
|
||||||
|
v6: m.v6,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("flush ipset %s before use: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("flush ipset %s before use: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := m.createIPSet(ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("create ipset: %w", err)
|
||||||
|
}
|
||||||
|
if err := m.addToIPSet(ipsetName, ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("add IP to ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ipList := newIpList(ip.String())
|
||||||
|
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return nil, fmt.Errorf("rule already exists")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// Insert at the beginning of the chain (position 1)
|
||||||
|
err = m.iptablesClient.Insert(tableFilter, chain, 1, specs...)
|
||||||
|
} else {
|
||||||
|
err = m.iptablesClient.Append(tableFilter, chain, specs...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle rule: %v", err)
|
||||||
|
mangleSpecs = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := &Rule{
|
||||||
|
ruleID: uuid.New().String(),
|
||||||
|
specs: specs,
|
||||||
|
mangleSpecs: mangleSpecs,
|
||||||
|
ipsetName: ipsetName,
|
||||||
|
ip: ip.String(),
|
||||||
|
chain: chain,
|
||||||
|
v6: m.v6,
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return []firewall.Rule{rule}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid rule type")
|
||||||
|
}
|
||||||
|
|
||||||
|
shouldDestroyIpset := false
|
||||||
|
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
|
||||||
|
// delete IP from ruleset IPs list and ipset
|
||||||
|
if _, ok := ipsetList.ips[r.ip]; ok {
|
||||||
|
ip := net.ParseIP(r.ip)
|
||||||
|
if ip == nil {
|
||||||
|
return fmt.Errorf("parse IP %s", r.ip)
|
||||||
|
}
|
||||||
|
if err := m.delFromIPSet(r.ipsetName, ip); err != nil {
|
||||||
|
return fmt.Errorf("delete ip from ipset: %w", err)
|
||||||
|
}
|
||||||
|
delete(ipsetList.ips, r.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if after delete, set still contains other IPs,
|
||||||
|
// no need to delete firewall rule and we should exit here
|
||||||
|
if len(ipsetList.ips) != 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we delete last IP from the set, that means we need to delete
|
||||||
|
// set itself and associated firewall rule too
|
||||||
|
m.ipsetStore.deleteIpset(r.ipsetName)
|
||||||
|
shouldDestroyIpset = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.mangleSpecs != nil {
|
||||||
|
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if shouldDestroyIpset {
|
||||||
|
if err := m.destroyIPSet(r.ipsetName); err != nil {
|
||||||
|
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("destroy empty ipset: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("destroy empty ipset: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) Reset() error {
|
||||||
|
if err := m.cleanChains(); err != nil {
|
||||||
|
return fmt.Errorf("clean chains: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateState()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo write less destructive cleanup mechanism
|
||||||
|
func (m *aclManager) cleanChains() error {
|
||||||
|
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to list chains: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
for _, rule := range m.entries["INPUT"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists(tableName, "INPUT", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range m.entries["FORWARD"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists(tableName, "FORWARD", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameInputRules)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to clear and delete %s chain: %s", chainNameInputRules, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("list chains: %w", err)
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
for _, rule := range m.entries["PREROUTING"] {
|
||||||
|
err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range m.entries[mangleFwdKey] {
|
||||||
|
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ipsetName := range m.ipsetStore.ipsetNames() {
|
||||||
|
if err := m.flushIPSet(ipsetName); err != nil {
|
||||||
|
if errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := m.destroyIPSet(ipsetName); err != nil {
|
||||||
|
if errors.Is(err, ipset.ErrBusy) || errors.Is(err, ipset.ErrSetNotExist) {
|
||||||
|
log.Debugf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||||
|
} else {
|
||||||
|
log.Errorf("destroy ipset %q during reset: %v", ipsetName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.ipsetStore.deleteIpset(ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) createDefaultChains() error {
|
||||||
|
// chain netbird-acl-input-rules
|
||||||
|
if err := m.iptablesClient.NewChain(tableName, chainNameInputRules); err != nil {
|
||||||
|
log.Debugf("failed to create '%s' chain: %s", chainNameInputRules, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for chainName, rules := range m.entries {
|
||||||
|
// mangle FORWARD guard rules are handled separately below
|
||||||
|
if chainName == mangleFwdKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, rule := range rules {
|
||||||
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
|
log.Debugf("failed to create input chain jump rule: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for chainName, entries := range m.optionalEntries {
|
||||||
|
for _, entry := range entries {
|
||||||
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
|
||||||
|
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.entries[chainName] = append(m.entries[chainName], entry.spec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
clear(m.optionalEntries)
|
||||||
|
|
||||||
|
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
||||||
|
for _, rule := range m.entries[mangleFwdKey] {
|
||||||
|
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
|
||||||
|
// We want to make sure our traffic is not dropped by existing rules.
|
||||||
|
|
||||||
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
|
func (m *aclManager) seedInitialEntries() {
|
||||||
|
established := getConntrackEstablished()
|
||||||
|
|
||||||
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
|
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||||
|
|
||||||
|
// Inbound is handled by our ACLs, the rest is dropped.
|
||||||
|
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
|
|
||||||
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
|
|
||||||
|
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
|
||||||
|
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
|
||||||
|
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
|
||||||
|
// ACL mark check where it cannot be overridden.
|
||||||
|
m.appendToEntries(mangleFwdKey, []string{
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
})
|
||||||
|
m.appendToEntries(mangleFwdKey, []string{
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "DNAT",
|
||||||
|
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
"-j", "DROP",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
|
{
|
||||||
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||||
|
position: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
|
m.entries[chainName] = append(m.entries[chainName], spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) updateState() {
|
||||||
|
if m.stateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var currentState *ShutdownState
|
||||||
|
if existing := m.stateManager.GetState(currentState); existing != nil {
|
||||||
|
if existingState, ok := existing.(*ShutdownState); ok {
|
||||||
|
currentState = existingState
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentState == nil {
|
||||||
|
currentState = &ShutdownState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState.Lock()
|
||||||
|
defer currentState.Unlock()
|
||||||
|
|
||||||
|
if m.v6 {
|
||||||
|
currentState.ACLEntries6 = m.entries
|
||||||
|
currentState.ACLIPsetStore6 = m.ipsetStore
|
||||||
|
} else {
|
||||||
|
currentState.ACLEntries = m.entries
|
||||||
|
currentState.ACLIPsetStore = m.ipsetStore
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.stateManager.UpdateState(currentState); err != nil {
|
||||||
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
|
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
||||||
|
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
||||||
|
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
||||||
|
if v6 && protocol == firewall.ProtocolICMP {
|
||||||
|
return "ipv6-icmp"
|
||||||
|
}
|
||||||
|
return string(protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
|
// don't use IP matching if IP is 0.0.0.0
|
||||||
|
matchByIP := !ip.IsUnspecified()
|
||||||
|
|
||||||
|
if matchByIP {
|
||||||
|
if ipsetName != "" {
|
||||||
|
specs = append(specs, "-m", "set", "--match-set", ipsetName, "src")
|
||||||
|
} else {
|
||||||
|
specs = append(specs, "-s", ip.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if protocol != "all" {
|
||||||
|
specs = append(specs, "-p", protocol)
|
||||||
|
}
|
||||||
|
specs = append(specs, applyPort("--sport", sPort)...)
|
||||||
|
specs = append(specs, applyPort("--dport", dPort)...)
|
||||||
|
return specs
|
||||||
|
}
|
||||||
|
|
||||||
|
func actionToStr(action firewall.Action) string {
|
||||||
|
if action == firewall.ActionAccept {
|
||||||
|
return "ACCEPT"
|
||||||
|
}
|
||||||
|
return "DROP"
|
||||||
|
}
|
||||||
|
|
||||||
|
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port, action firewall.Action) string {
|
||||||
|
if ipsetName == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
actionSuffix := ""
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
actionSuffix = "-drop"
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case sPort != nil && dPort != nil:
|
||||||
|
return ipsetName + "-sport-dport" + actionSuffix
|
||||||
|
case sPort != nil:
|
||||||
|
return ipsetName + "-sport" + actionSuffix
|
||||||
|
case dPort != nil:
|
||||||
|
return ipsetName + "-dport" + actionSuffix
|
||||||
|
default:
|
||||||
|
return ipsetName + actionSuffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) createIPSet(name string) error {
|
||||||
|
opts := ipset.CreateOptions{
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
if m.v6 {
|
||||||
|
opts.Family = ipset.FamilyIPV6
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
||||||
|
return fmt.Errorf("create ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("created ipset %s with type hash:net", name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) addToIPSet(name string, ip net.IP) error {
|
||||||
|
cidr := uint8(32)
|
||||||
|
if ip.To4() == nil {
|
||||||
|
cidr = 128
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: cidr,
|
||||||
|
Replace: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Add(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("add IP to ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) delFromIPSet(name string, ip net.IP) error {
|
||||||
|
cidr := uint8(32)
|
||||||
|
if ip.To4() == nil {
|
||||||
|
cidr = 128
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := &ipset.Entry{
|
||||||
|
IP: ip,
|
||||||
|
CIDR: cidr,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ipset.Del(name, entry); err != nil {
|
||||||
|
return fmt.Errorf("delete IP from ipset %s: %w", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) flushIPSet(name string) error {
|
||||||
|
return ipset.Flush(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *aclManager) destroyIPSet(name string) error {
|
||||||
|
return ipset.Destroy(name)
|
||||||
|
}
|
||||||
@@ -1,346 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) createContainers() error {
|
|
||||||
for _, chainInfo := range []struct {
|
|
||||||
chain string
|
|
||||||
table string
|
|
||||||
}{
|
|
||||||
{chainRTFwdIn, tableFilter},
|
|
||||||
{chainRTFwdOut, tableFilter},
|
|
||||||
{chainRTPre, tableMangle},
|
|
||||||
{chainRTNAT, tableNat},
|
|
||||||
{chainRTRdr, tableNat},
|
|
||||||
{chainRTMSSClamp, tableMangle},
|
|
||||||
} {
|
|
||||||
// Fallback: clear chains that survived an unclean shutdown.
|
|
||||||
if ok, _ := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain); ok {
|
|
||||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
|
||||||
log.Warnf("clear stale chain %s in %s: %v", chainInfo.chain, chainInfo.table, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.insertEstablishedRule(chainRTFwdIn); err != nil {
|
|
||||||
return fmt.Errorf("insert established rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.insertEstablishedRule(chainRTFwdOut); err != nil {
|
|
||||||
return fmt.Errorf("insert established rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
|
||||||
return fmt.Errorf("add static nat rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addJumpRules(); err != nil {
|
|
||||||
return fmt.Errorf("add jump rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addMSSClampingRules(); err != nil {
|
|
||||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addJumpRules() error {
|
|
||||||
// Jump to nat chain
|
|
||||||
natRule := jumpRuleSpec(chainRTNAT)
|
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPostrouting, 1, natRule...); err != nil {
|
|
||||||
return fmt.Errorf("add nat postrouting jump rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[jumpNATPost] = natRule
|
|
||||||
|
|
||||||
// Jump to mangle prerouting chain
|
|
||||||
preRule := jumpRuleSpec(chainRTPre)
|
|
||||||
if err := r.iptablesClient.Insert(tableMangle, chainPrerouting, 1, preRule...); err != nil {
|
|
||||||
return fmt.Errorf("add mangle prerouting jump rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[jumpManglePre] = preRule
|
|
||||||
|
|
||||||
// Jump to nat prerouting chain
|
|
||||||
rdrRule := jumpRuleSpec(chainRTRdr)
|
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPrerouting, 1, rdrRule...); err != nil {
|
|
||||||
return fmt.Errorf("add nat prerouting jump rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[jumpNATPre] = rdrRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) setupDataPlaneMark() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
preRule := []string{
|
|
||||||
"-i", r.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "NEW",
|
|
||||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPrerouting, preRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
|
||||||
} else {
|
|
||||||
r.rules[markManglePre] = preRule
|
|
||||||
}
|
|
||||||
|
|
||||||
postRule := []string{
|
|
||||||
"-o", r.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "NEW",
|
|
||||||
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainPostrouting, postRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
|
||||||
} else {
|
|
||||||
r.rules[markManglePost] = postRule
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// seedInitialEntries adds default rules to the entries map. Rules are
|
|
||||||
// inserted at position 1, so the order here is reversed.
|
|
||||||
//
|
|
||||||
// Existing FORWARD policy decides outbound traffic towards our
|
|
||||||
// interface. If FORWARD policy is "drop", we add an
|
|
||||||
// established/related rule to allow return traffic for inbound rules.
|
|
||||||
func (r *family) seedInitialEntries() {
|
|
||||||
established := getConntrackEstablished()
|
|
||||||
|
|
||||||
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
|
||||||
r.appendToEntries(chainInput, []string{"-i", r.wgIface.Name(), "-j", chainACLInput})
|
|
||||||
r.appendToEntries(chainInput, append([]string{"-i", r.wgIface.Name()}, established...))
|
|
||||||
|
|
||||||
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", "DROP"})
|
|
||||||
r.appendToEntries(chainForward, []string{"-o", r.wgIface.Name(), "-j", chainRTFwdOut})
|
|
||||||
r.appendToEntries(chainForward, []string{"-i", r.wgIface.Name(), "-j", chainRTFwdIn})
|
|
||||||
|
|
||||||
// Mangle FORWARD guard: when external DNAT redirects traffic from
|
|
||||||
// the wg interface, it traverses FORWARD instead of INPUT,
|
|
||||||
// bypassing ACL rules. ACCEPT rules in filter FORWARD can be
|
|
||||||
// inserted above ours. Mangle runs before filter, so these guard
|
|
||||||
// rules enforce the ACL mark check where it cannot be overridden.
|
|
||||||
r.appendToEntries(mangleForwardKey, []string{
|
|
||||||
"-i", r.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
|
|
||||||
"-j", "ACCEPT",
|
|
||||||
})
|
|
||||||
r.appendToEntries(mangleForwardKey, []string{
|
|
||||||
"-i", r.wgIface.Name(),
|
|
||||||
"-m", "conntrack", "--ctstate", "DNAT",
|
|
||||||
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
|
||||||
"-j", "DROP",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) seedInitialOptionalEntries() {
|
|
||||||
r.optionalEntries[chainForward] = []entry{
|
|
||||||
{
|
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
|
||||||
position: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) appendToEntries(chain chainKey, spec ruleSpec) {
|
|
||||||
r.entries[chain] = append(r.entries[chain], spec)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) createDefaultChains() error {
|
|
||||||
if err := r.iptablesClient.NewChain(tableFilter, chainACLInput); err != nil {
|
|
||||||
return fmt.Errorf("create %s chain: %w", chainACLInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for chain, rules := range r.entries {
|
|
||||||
// mangle FORWARD guard rules are handled separately below
|
|
||||||
if chain == mangleForwardKey {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := r.iptablesClient.InsertUnique(tableFilter, string(chain), 1, rule...); err != nil {
|
|
||||||
return fmt.Errorf("insert jump rule into %s: %w", chain, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for chain, entries := range r.optionalEntries {
|
|
||||||
for _, entry := range entries {
|
|
||||||
if err := r.iptablesClient.InsertUnique(tableFilter, string(chain), entry.position, entry.spec...); err != nil {
|
|
||||||
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
r.entries[chain] = append(r.entries[chain], entry.spec)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
clear(r.optionalEntries)
|
|
||||||
|
|
||||||
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
|
|
||||||
for _, rule := range r.entries[mangleForwardKey] {
|
|
||||||
if err := r.iptablesClient.AppendUnique(tableMangle, chainForward, rule...); err != nil {
|
|
||||||
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) cleanUpDefaultForwardRules() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
// cleanJumpRules removes the OUTPUT jump to NETBIRD-NAT-OUTPUT among
|
|
||||||
// the others, so the chain below deletes cleanly instead of failing
|
|
||||||
// with "device or resource busy".
|
|
||||||
if err := r.cleanJumpRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("clean jump rules: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chainInfo := range []struct {
|
|
||||||
chain string
|
|
||||||
table string
|
|
||||||
}{
|
|
||||||
{chainRTFwdIn, tableFilter},
|
|
||||||
{chainRTFwdOut, tableFilter},
|
|
||||||
{chainRTPre, tableMangle},
|
|
||||||
{chainRTNAT, tableNat},
|
|
||||||
{chainRTRdr, tableNat},
|
|
||||||
{chainNATOutput, tableNat},
|
|
||||||
{chainRTMSSClamp, tableMangle},
|
|
||||||
} {
|
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
|
||||||
if err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
if err := r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) cleanJumpRules() error {
|
|
||||||
// locations maps each jump rule to the built-in table and chain it
|
|
||||||
// was inserted into, plus the netbird chain it targets.
|
|
||||||
locations := map[firewall.RuleID]struct{ table, chain, target string }{
|
|
||||||
jumpNATPost: {tableNat, chainPostrouting, chainRTNAT},
|
|
||||||
jumpManglePre: {tableMangle, chainPrerouting, chainRTPre},
|
|
||||||
jumpNATPre: {tableNat, chainPrerouting, chainRTRdr},
|
|
||||||
jumpMSSClamp: {tableMangle, chainForward, chainRTMSSClamp},
|
|
||||||
jumpNATOutput: {tableNat, chainOutput, chainNATOutput},
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for ruleID, loc := range locations {
|
|
||||||
rule, exists := r.rules[ruleID]
|
|
||||||
if !exists {
|
|
||||||
// Untracked (e.g. fresh start after an unclean shutdown with no
|
|
||||||
// restored state): if the target chain survived, remove the stale
|
|
||||||
// jump to it so the chain can be deleted.
|
|
||||||
ok, err := r.iptablesClient.ChainExists(loc.table, loc.target)
|
|
||||||
if err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("check chain %s in table %s: %w", loc.target, loc.table, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rule = jumpRuleSpec(loc.target)
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(loc.table, loc.chain, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete rule from chain %s in table %s: %w", loc.chain, loc.table, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// jumpRuleSpec builds the iptables rule spec that jumps to target. Create
|
|
||||||
// and cleanup sites share it so the installed and deleted specs cannot drift.
|
|
||||||
func jumpRuleSpec(target string) []string {
|
|
||||||
return []string{"-j", target}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) cleanAclChains() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := r.cleanInputAclChain(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range r.entries[mangleForwardKey] {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainForward, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete mangle %s guard rule %v: %w", chainForward, rule, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) cleanInputAclChain() error {
|
|
||||||
ok, err := r.iptablesClient.ChainExists(tableFilter, chainACLInput)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("check chain %s: %w", chainACLInput, err)
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, rule := range r.entries[chainInput] {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainInput, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainInput, rule, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range r.entries[chainForward] {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainForward, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete %s rule %v: %w", chainForward, rule, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.iptablesClient.ClearAndDeleteChain(tableFilter, chainACLInput); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("clear and delete %s chain: %w", chainACLInput, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) cleanupDataPlaneMark() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
if preRule, exists := r.rules[markManglePre]; exists {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPrerouting, preRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
|
||||||
} else {
|
|
||||||
delete(r.rules, markManglePre)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if postRule, exists := r.rules[markManglePost]; exists {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPostrouting, postRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
|
||||||
} else {
|
|
||||||
delete(r.rules, markManglePost)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|
||||||
ruleID := rule.ID()
|
|
||||||
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
toDestination := rule.TranslatedAddress.String()
|
|
||||||
switch {
|
|
||||||
case len(rule.TranslatedPort.Values) == 0:
|
|
||||||
// no translated port, use original port
|
|
||||||
case len(rule.TranslatedPort.Values) == 1:
|
|
||||||
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
|
||||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
|
||||||
// need the "/originalport" suffix to avoid dnat port randomization
|
|
||||||
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
|
||||||
}
|
|
||||||
|
|
||||||
proto := strings.ToLower(string(rule.Protocol))
|
|
||||||
|
|
||||||
rules := make(map[firewall.RuleID]ruleInfo, 3)
|
|
||||||
|
|
||||||
// DNAT rule
|
|
||||||
dnatRule := []string{
|
|
||||||
"!", "-i", r.wgIface.Name(),
|
|
||||||
"-p", proto,
|
|
||||||
"-j", "DNAT",
|
|
||||||
"--to-destination", toDestination,
|
|
||||||
}
|
|
||||||
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
|
||||||
rules[ruleID+dnatSuffix] = ruleInfo{
|
|
||||||
table: tableNat,
|
|
||||||
chain: chainRTRdr,
|
|
||||||
rule: dnatRule,
|
|
||||||
}
|
|
||||||
|
|
||||||
// SNAT rule
|
|
||||||
snatRule := []string{
|
|
||||||
"-o", r.wgIface.Name(),
|
|
||||||
"-p", proto,
|
|
||||||
"-d", rule.TranslatedAddress.String(),
|
|
||||||
"-j", "MASQUERADE",
|
|
||||||
}
|
|
||||||
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
|
||||||
rules[ruleID+snatSuffix] = ruleInfo{
|
|
||||||
table: tableNat,
|
|
||||||
chain: chainRTNAT,
|
|
||||||
rule: snatRule,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forward filtering rule, if fwd policy is DROP
|
|
||||||
forwardRule := []string{
|
|
||||||
"-o", r.wgIface.Name(),
|
|
||||||
"-p", proto,
|
|
||||||
"-d", rule.TranslatedAddress.String(),
|
|
||||||
"-j", "ACCEPT",
|
|
||||||
}
|
|
||||||
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
|
||||||
rules[ruleID+fwdSuffix] = ruleInfo{
|
|
||||||
table: tableFilter,
|
|
||||||
chain: chainRTFwdOut,
|
|
||||||
rule: forwardRule,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request forwarding once the rule is about to be installed, releasing
|
|
||||||
// it if installation fails so the refcount tracks the real rules.
|
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, ruleInfo := range rules {
|
|
||||||
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
|
||||||
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
|
||||||
log.Errorf("rollback failed: %v", rollbackErr)
|
|
||||||
}
|
|
||||||
r.releaseForwarding()
|
|
||||||
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
|
||||||
}
|
|
||||||
r.rules[key] = ruleInfo.rule
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) rollbackRules(rules map[firewall.RuleID]ruleInfo) error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for key, ruleInfo := range rules {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
|
||||||
// On rollback error, add to rules map for next cleanup
|
|
||||||
r.rules[key] = ruleInfo.rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if merr != nil {
|
|
||||||
r.updateState()
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
|
||||||
ruleID := rule.ID()
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
var found bool
|
|
||||||
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
|
||||||
found = true
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID+dnatSuffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
if snatRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
|
||||||
found = true
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID+snatSuffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwdRule, exists := r.rules[ruleID+fwdSuffix]; exists {
|
|
||||||
found = true
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFwdOut, fwdRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID+fwdSuffix)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
|
|
||||||
// Release once, only if the rule was present and removed.
|
|
||||||
if merr == nil && found {
|
|
||||||
r.releaseForwarding()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseForwarding drops one IP forwarding reference, logging any error.
|
|
||||||
func (r *family) releaseForwarding() {
|
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
log.Errorf("release IP forwarding: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
dnatRule := []string{
|
|
||||||
"-i", r.wgIface.Name(),
|
|
||||||
"-p", strings.ToLower(protoForFamily(protocol, r.v6)),
|
|
||||||
"--dport", strconv.Itoa(int(originalPort)),
|
|
||||||
"-d", localAddr.String(),
|
|
||||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
|
||||||
"-j", "DNAT",
|
|
||||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
|
||||||
}
|
|
||||||
|
|
||||||
info := ruleInfo{
|
|
||||||
table: tableNat,
|
|
||||||
chain: chainRTRdr,
|
|
||||||
rule: dnatRule,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(info.table, info.chain, info.rule...); err != nil {
|
|
||||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[ruleID] = info.rule
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
|
||||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainRTRdr, dnatRule...); err != nil {
|
|
||||||
return fmt.Errorf("delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain and jump rule on first use.
|
|
||||||
func (r *family) ensureNATOutputChain() error {
|
|
||||||
if _, exists := r.rules[jumpNATOutput]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
chainExists, err := r.iptablesClient.ChainExists(tableNat, chainNATOutput)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("check chain %s: %w", chainNATOutput, err)
|
|
||||||
}
|
|
||||||
if !chainExists {
|
|
||||||
if err := r.iptablesClient.NewChain(tableNat, chainNATOutput); err != nil {
|
|
||||||
return fmt.Errorf("create chain %s: %w", chainNATOutput, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
jumpRule := jumpRuleSpec(chainNATOutput)
|
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainOutput, 1, jumpRule...); err != nil {
|
|
||||||
if !chainExists {
|
|
||||||
if delErr := r.iptablesClient.ClearAndDeleteChain(tableNat, chainNATOutput); delErr != nil {
|
|
||||||
log.Warnf("failed to rollback chain %s: %v", chainNATOutput, delErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add OUTPUT jump rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[jumpNATOutput] = jumpRule
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.ensureNATOutputChain(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dnatRule := []string{
|
|
||||||
"-p", strings.ToLower(protoForFamily(protocol, localAddr.Is6())),
|
|
||||||
"--dport", strconv.Itoa(int(originalPort)),
|
|
||||||
"-d", localAddr.String(),
|
|
||||||
"-j", "DNAT",
|
|
||||||
"--to-destination", ":" + strconv.Itoa(int(translatedPort)),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(tableNat, chainNATOutput, dnatRule...); err != nil {
|
|
||||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[ruleID] = dnatRule
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.iptablesClient.Delete(tableNat, chainNATOutput, dnatRule...); err != nil {
|
|
||||||
return fmt.Errorf("delete output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,252 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"maps"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
|
||||||
const (
|
|
||||||
tableFilter = "filter"
|
|
||||||
tableNat = "nat"
|
|
||||||
tableMangle = "mangle"
|
|
||||||
|
|
||||||
// chainACLInput is the peer ACL chain that holds installed
|
|
||||||
// peer-filtering rules.
|
|
||||||
chainACLInput = "NETBIRD-ACL-INPUT"
|
|
||||||
|
|
||||||
// mangleForwardKey is the entries map key for mangle FORWARD guard
|
|
||||||
// rules that prevent external DNAT from bypassing ACL rules.
|
|
||||||
mangleForwardKey chainKey = "MANGLE-FORWARD"
|
|
||||||
|
|
||||||
chainInput = "INPUT"
|
|
||||||
chainPostrouting = "POSTROUTING"
|
|
||||||
chainPrerouting = "PREROUTING"
|
|
||||||
chainForward = "FORWARD"
|
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
|
||||||
chainRTFwdIn = "NETBIRD-RT-FWD-IN"
|
|
||||||
chainRTFwdOut = "NETBIRD-RT-FWD-OUT"
|
|
||||||
chainRTPre = "NETBIRD-RT-PRE"
|
|
||||||
chainRTRdr = "NETBIRD-RT-RDR"
|
|
||||||
chainNATOutput = "NETBIRD-NAT-OUTPUT"
|
|
||||||
chainRTMSSClamp = "NETBIRD-RT-MSSCLAMP"
|
|
||||||
|
|
||||||
jumpManglePre = "jump-mangle-pre"
|
|
||||||
jumpNATPre = "jump-nat-pre"
|
|
||||||
jumpNATPost = "jump-nat-post"
|
|
||||||
jumpNATOutput = "jump-nat-output"
|
|
||||||
jumpMSSClamp = "jump-mss-clamp"
|
|
||||||
markManglePre = "mark-mangle-pre"
|
|
||||||
markManglePost = "mark-mangle-post"
|
|
||||||
matchSet = "--match-set"
|
|
||||||
|
|
||||||
dnatSuffix firewall.RuleID = "_dnat"
|
|
||||||
snatSuffix firewall.RuleID = "_snat"
|
|
||||||
fwdSuffix firewall.RuleID = "_fwd"
|
|
||||||
|
|
||||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
|
||||||
ipv4TCPHeaderSize = 40
|
|
||||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
|
||||||
ipv6TCPHeaderSize = 60
|
|
||||||
)
|
|
||||||
|
|
||||||
type ruleInfo struct {
|
|
||||||
chain string
|
|
||||||
table string
|
|
||||||
rule []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type routeRules map[firewall.RuleID][]string
|
|
||||||
|
|
||||||
// ruleSpec is a single iptables rule expressed as its argument list
|
|
||||||
// (e.g. {"-i", "wg0", "-j", "DROP"}).
|
|
||||||
type ruleSpec []string
|
|
||||||
|
|
||||||
// chainKey identifies the chain a seeded entry belongs to. It holds
|
|
||||||
// built-in chain names ("INPUT", "FORWARD", "PREROUTING") plus the
|
|
||||||
// synthetic mangleForwardKey bucket for the mangle FORWARD guard rules.
|
|
||||||
type chainKey string
|
|
||||||
|
|
||||||
// aclEntries maps a chain to the rules seeded into it to jump into or
|
|
||||||
// guard the netbird ACL chains.
|
|
||||||
type aclEntries map[chainKey][]ruleSpec
|
|
||||||
|
|
||||||
type entry struct {
|
|
||||||
spec ruleSpec
|
|
||||||
position int
|
|
||||||
}
|
|
||||||
|
|
||||||
// ipsetCounter is the shared hash:net refcounter used by peer and
|
|
||||||
// route ACLs alike. The ipset library does not support comments, so
|
|
||||||
// the key is just the set name (string).
|
|
||||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
|
||||||
|
|
||||||
// family holds the per-address-family iptables state. One instance
|
|
||||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
|
||||||
// single family; the top-level Manager owns one for v4 and another
|
|
||||||
// for v6.
|
|
||||||
type family struct {
|
|
||||||
iptablesClient *iptables.IPTables
|
|
||||||
wgIface iFaceMapper
|
|
||||||
v6 bool
|
|
||||||
|
|
||||||
// Peer ACL chain bookkeeping.
|
|
||||||
entries aclEntries
|
|
||||||
optionalEntries map[chainKey][]entry
|
|
||||||
|
|
||||||
// filters holds peer + route filter rules keyed by content hash.
|
|
||||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
|
||||||
filters map[nbid.RuleID]*Rule
|
|
||||||
ipsetCounter *ipsetCounter
|
|
||||||
|
|
||||||
// rules holds NAT, jump, and MSS-clamping rules (auxiliary
|
|
||||||
// plumbing that isn't a filter rule).
|
|
||||||
rules routeRules
|
|
||||||
|
|
||||||
// Routing / NAT.
|
|
||||||
legacyManagement bool
|
|
||||||
mtu uint16
|
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFamily(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*family, error) {
|
|
||||||
r := &family{
|
|
||||||
iptablesClient: iptablesClient,
|
|
||||||
wgIface: wgIface,
|
|
||||||
v6: iptablesClient.Proto() == iptables.ProtocolIPv6,
|
|
||||||
entries: make(aclEntries),
|
|
||||||
optionalEntries: make(map[chainKey][]entry),
|
|
||||||
filters: make(map[nbid.RuleID]*Rule),
|
|
||||||
rules: make(routeRules),
|
|
||||||
mtu: mtu,
|
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
|
||||||
}
|
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
|
||||||
func(name string, sources []netip.Prefix) (struct{}, error) {
|
|
||||||
return struct{}{}, r.createIpSet(name, sources)
|
|
||||||
},
|
|
||||||
func(name string, _ struct{}) error {
|
|
||||||
return r.deleteIpSet(name)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return r, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// init wires the family to the state manager and installs both the
|
|
||||||
// route ACL containers and the peer ACL chain skeleton.
|
|
||||||
func (r *family) init(stateManager *statemanager.Manager) error {
|
|
||||||
r.stateManager = stateManager
|
|
||||||
|
|
||||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.createContainers(); err != nil {
|
|
||||||
return fmt.Errorf("create containers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.setupDataPlaneMark(); err != nil {
|
|
||||||
log.Errorf("failed to set up data plane mark: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.seedInitialEntries()
|
|
||||||
r.seedInitialOptionalEntries()
|
|
||||||
|
|
||||||
if err := r.cleanAclChains(); err != nil {
|
|
||||||
return fmt.Errorf("clean acl chains: %w", err)
|
|
||||||
}
|
|
||||||
if err := r.createDefaultChains(); err != nil {
|
|
||||||
return fmt.Errorf("create default chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset tears down all firewall state owned by this family. ACL
|
|
||||||
// chain cleanup runs before route-chain cleanup because the route
|
|
||||||
// chains are still referenced by FORWARD jumps installed during
|
|
||||||
// seedInitialEntries; deleting them first would trip EBUSY.
|
|
||||||
func (r *family) Reset() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := r.cleanAclChains(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.ipsetCounter.Flush(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.cleanupDataPlaneMark(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
clear(r.rules)
|
|
||||||
clear(r.filters)
|
|
||||||
r.updateState()
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) updateState() {
|
|
||||||
if r.stateManager == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var currentState *ShutdownState
|
|
||||||
if existing := r.stateManager.GetState(currentState); existing != nil {
|
|
||||||
if existingState, ok := existing.(*ShutdownState); ok {
|
|
||||||
currentState = existingState
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if currentState == nil {
|
|
||||||
currentState = &ShutdownState{}
|
|
||||||
}
|
|
||||||
|
|
||||||
currentState.Lock()
|
|
||||||
defer currentState.Unlock()
|
|
||||||
|
|
||||||
// Clone the rule maps so the persisted state holds a private snapshot.
|
|
||||||
// The live maps keep being mutated by subsequent rule operations while
|
|
||||||
// the state manager marshals the state from its periodic-save goroutine.
|
|
||||||
// Sharing the maps by reference races the two and aborts the process with
|
|
||||||
// a concurrent map iteration and write. The ipset counter guards itself
|
|
||||||
// during marshaling, so it can be shared directly.
|
|
||||||
if r.v6 {
|
|
||||||
currentState.RouteRules6 = maps.Clone(r.rules)
|
|
||||||
currentState.RouteIPsetCounter6 = r.ipsetCounter
|
|
||||||
currentState.ACLEntries6 = maps.Clone(r.entries)
|
|
||||||
} else {
|
|
||||||
currentState.RouteRules = maps.Clone(r.rules)
|
|
||||||
currentState.RouteIPsetCounter = r.ipsetCounter
|
|
||||||
currentState.ACLEntries = maps.Clone(r.entries)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.stateManager.UpdateState(currentState); err != nil {
|
|
||||||
log.Errorf("failed to update state: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,346 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"slices"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddFilterRule installs a packet-filtering rule. With destination
|
|
||||||
// empty, the rule goes to the peer ACL input chain plus a paired
|
|
||||||
// mangle PREROUTING rule for the redirect mark. With destination set
|
|
||||||
// (prefix or named set), it goes to the route ACL forward chain.
|
|
||||||
// Multi-source rules collapse to one iptables rule via the shared
|
|
||||||
// hash:net ipset.
|
|
||||||
func (r *family) AddFilterRule(
|
|
||||||
id []byte,
|
|
||||||
sources []netip.Prefix,
|
|
||||||
destination firewall.Network,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
) (firewall.Rule, error) {
|
|
||||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
|
||||||
if existing, ok := r.filters[ruleID]; ok {
|
|
||||||
return existing, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
srcMatch, err := r.applySourceMatch(sourceNetwork(sources), sources)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("apply source match: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rule, err := r.installFilterRule(ruleID, srcMatch, destination, proto, sPort, dPort, action)
|
|
||||||
if err != nil {
|
|
||||||
r.dropSourceMatch(srcMatch)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.filters[ruleID] = rule
|
|
||||||
r.updateState()
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) hasRule(id nbid.RuleID) bool {
|
|
||||||
_, ok := r.filters[id]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// hasDNATRule reports whether this family owns the DNAT rule set for
|
|
||||||
// the given user id. DNAT rules live in r.rules under the well-known
|
|
||||||
// "<id>_dnat" key; the lookup here is used by Manager.DeleteDNATRule
|
|
||||||
// to pick the right family.
|
|
||||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
|
||||||
_, ok := r.rules[id+dnatSuffix]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteFilterRule removes a previously installed filter rule. The
|
|
||||||
// rule's stored chain/table identify where to delete from; source set
|
|
||||||
// references are recovered from the spec via findSets and dropped
|
|
||||||
// from the shared ipset counter.
|
|
||||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
|
||||||
ruleID := rule.ID()
|
|
||||||
pr, ok := r.filters[ruleID]
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("filter rule %s not found", ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteIfExists keeps both deletes idempotent so a retry after a
|
|
||||||
// partial failure does not error on the half that was already removed.
|
|
||||||
var merr *multierror.Error
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, pr.chain, pr.specs...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete rule from %s: %w", pr.chain, err))
|
|
||||||
}
|
|
||||||
if pr.mangleSpecs != nil {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, pr.mangleSpecs...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete mangle rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if merr != nil {
|
|
||||||
// Leave the rule tracked so the caller retries the remaining half.
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The rule is gone from iptables, so untrack it regardless of how the
|
|
||||||
// refcount decrement goes, but surface decrement failures so callers
|
|
||||||
// see the ipset desync.
|
|
||||||
delete(r.filters, ruleID)
|
|
||||||
r.updateState()
|
|
||||||
if err := r.decrementSetCounter(pr.specs); err != nil {
|
|
||||||
return fmt.Errorf("drop source set references: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// findSets scans an iptables rule spec for "-m set --match-set <name>
|
|
||||||
// <dir>" fragments and returns the named sets in occurrence order.
|
|
||||||
// Used at delete time to drop ipsetCounter references.
|
|
||||||
func findSets(rule []string) []string {
|
|
||||||
var sets []string
|
|
||||||
for i, arg := range rule {
|
|
||||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
|
||||||
sets = append(sets, rule[i+3])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sets
|
|
||||||
}
|
|
||||||
|
|
||||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
|
||||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
|
||||||
// single prefix inline, or an ipset for multiple sources.
|
|
||||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
|
||||||
switch {
|
|
||||||
case len(sources) == 0:
|
|
||||||
return firewall.Network{}
|
|
||||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
|
||||||
return firewall.Network{}
|
|
||||||
case len(sources) == 1:
|
|
||||||
return firewall.Network{Prefix: sources[0]}
|
|
||||||
default:
|
|
||||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// applySourceMatch returns the iptables match fragment for the rule's
|
|
||||||
// source. For a Set it increments the shared ipset's refcount; for a
|
|
||||||
// Prefix it emits a direct -s match; for the wildcard it returns nil.
|
|
||||||
func (r *family) applySourceMatch(network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
|
||||||
switch {
|
|
||||||
case network.IsSet():
|
|
||||||
if r.ipsetCounter == nil {
|
|
||||||
return nil, fmt.Errorf("multi-source peer rule requires shared ipset counter")
|
|
||||||
}
|
|
||||||
name := r.ipsetName(network.Set.HashedName())
|
|
||||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
|
||||||
return nil, fmt.Errorf("ipset increment %s: %w", name, err)
|
|
||||||
}
|
|
||||||
return []string{"-m", "set", matchSet, name, "src"}, nil
|
|
||||||
case network.IsPrefix():
|
|
||||||
return []string{"-s", network.Prefix.String()}, nil
|
|
||||||
default:
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropSourceMatch undoes whatever applySourceMatch reserved when
|
|
||||||
// installing a rule fails. Safe to call when the spec is empty or holds
|
|
||||||
// only inline matchers. Decrement errors are logged but not returned:
|
|
||||||
// the install error is what the caller needs to see.
|
|
||||||
func (r *family) dropSourceMatch(srcMatch []string) {
|
|
||||||
if r.ipsetCounter == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, name := range findSets(srcMatch) {
|
|
||||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
|
||||||
log.Errorf("rollback ipset decrement %s: %v", name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// decrementSetCounter drops ipset references owned by a raw rule spec
|
|
||||||
// stored in r.rules (NAT / legacy route entries). It returns an error
|
|
||||||
// aggregate so the caller surfaces decrement failures.
|
|
||||||
func (r *family) decrementSetCounter(rule []string) error {
|
|
||||||
if r.ipsetCounter == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, name := range findSets(rule) {
|
|
||||||
if _, err := r.ipsetCounter.Decrement(name); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// installFilterRule assembles and writes one iptables filter-chain
|
|
||||||
// rule. With destination empty the rule lands in the peer ACL input
|
|
||||||
// chain and a paired mangle PREROUTING rule is added for the redirect
|
|
||||||
// mark. With destination set the rule lands in the route ACL forward
|
|
||||||
// chain and there is no mangle pairing.
|
|
||||||
func (r *family) installFilterRule(
|
|
||||||
ruleID nbid.RuleID,
|
|
||||||
srcMatch []string,
|
|
||||||
destination firewall.Network,
|
|
||||||
protocol firewall.Protocol,
|
|
||||||
sPort, dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
) (*Rule, error) {
|
|
||||||
isRoute := !destination.IsZero()
|
|
||||||
|
|
||||||
proto := protoForFamily(protocol, r.v6)
|
|
||||||
|
|
||||||
specs := slices.Clone(srcMatch)
|
|
||||||
var destExp []string
|
|
||||||
if isRoute {
|
|
||||||
var err error
|
|
||||||
destExp, err = r.applyNetwork("-d", destination, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("apply network -d: %w", err)
|
|
||||||
}
|
|
||||||
specs = append(specs, destExp...)
|
|
||||||
}
|
|
||||||
specs = append(specs, filterMatchSpecs(proto, sPort, dPort)...)
|
|
||||||
|
|
||||||
var mangleSpecs []string
|
|
||||||
if !isRoute {
|
|
||||||
mangleSpecs = slices.Clone(specs)
|
|
||||||
mangleSpecs = append(mangleSpecs,
|
|
||||||
"-i", r.wgIface.Name(),
|
|
||||||
"-m", "addrtype", "--dst-type", "LOCAL",
|
|
||||||
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
specs = append(specs, "-j", actionToStr(action))
|
|
||||||
|
|
||||||
chain := chainACLInput
|
|
||||||
if isRoute {
|
|
||||||
chain = chainRTFwdIn
|
|
||||||
}
|
|
||||||
|
|
||||||
// Peer ACL drops are inserted at position 1 so they precede the
|
|
||||||
// chain's catch-all; route ACL drops are inserted at position 2
|
|
||||||
// to sit immediately after the established/related accept rule.
|
|
||||||
var err error
|
|
||||||
if action == firewall.ActionDrop {
|
|
||||||
pos := 1
|
|
||||||
if isRoute {
|
|
||||||
pos = 2
|
|
||||||
}
|
|
||||||
err = r.iptablesClient.Insert(tableFilter, chain, pos, specs...)
|
|
||||||
} else {
|
|
||||||
err = r.iptablesClient.Append(tableFilter, chain, specs...)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
r.dropSourceMatch(destExp)
|
|
||||||
return nil, fmt.Errorf("install filter rule on %s: %w", chain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// The mangle redirect-mark rule is best effort: the filter rule itself
|
|
||||||
// is what enforces the ACL, so a mangle failure must not undo it. Drop
|
|
||||||
// the spec so teardown does not try to remove a rule that was not added.
|
|
||||||
if mangleSpecs != nil {
|
|
||||||
if err := r.iptablesClient.Append(tableMangle, chainRTPre, mangleSpecs...); err != nil {
|
|
||||||
log.Errorf("add mangle rule: %v", err)
|
|
||||||
mangleSpecs = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &Rule{
|
|
||||||
id: ruleID,
|
|
||||||
specs: specs,
|
|
||||||
mangleSpecs: mangleSpecs,
|
|
||||||
chain: chain,
|
|
||||||
v6: r.v6,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyNetwork resolves a firewall.Network into the iptables match
|
|
||||||
// fragment for the given direction flag (-s or -d). Set networks
|
|
||||||
// increment the shared ipset refcount; prefixes emit a direct match;
|
|
||||||
// an empty network returns no spec ("match any").
|
|
||||||
func (r *family) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
|
||||||
direction := "src"
|
|
||||||
if flag == "-d" {
|
|
||||||
direction = "dst"
|
|
||||||
}
|
|
||||||
|
|
||||||
if network.IsSet() {
|
|
||||||
name := r.ipsetName(network.Set.HashedName())
|
|
||||||
if _, err := r.ipsetCounter.Increment(name, prefixes); err != nil {
|
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return []string{"-m", "set", matchSet, name, direction}, nil
|
|
||||||
}
|
|
||||||
if network.IsPrefix() {
|
|
||||||
return []string{flag, network.Prefix.String()}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// nolint:nilnil
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// protoForFamily translates ICMP to ICMPv6 for ip6tables.
|
|
||||||
// ip6tables requires "ipv6-icmp" (or "icmpv6") instead of "icmp".
|
|
||||||
func protoForFamily(protocol firewall.Protocol, v6 bool) string {
|
|
||||||
if v6 && protocol == firewall.ProtocolICMP {
|
|
||||||
return "ipv6-icmp"
|
|
||||||
}
|
|
||||||
return string(protocol)
|
|
||||||
}
|
|
||||||
|
|
||||||
// filterMatchSpecs returns the proto/port match fragment for a
|
|
||||||
// filtering rule. The source match (-s or -m set) is built by the
|
|
||||||
// caller and prepended.
|
|
||||||
func filterMatchSpecs(protocol string, sPort, dPort *firewall.Port) (specs []string) {
|
|
||||||
if protocol != "all" {
|
|
||||||
specs = append(specs, "-p", protocol)
|
|
||||||
}
|
|
||||||
specs = append(specs, applyPort("--sport", sPort)...)
|
|
||||||
specs = append(specs, applyPort("--dport", dPort)...)
|
|
||||||
return specs
|
|
||||||
}
|
|
||||||
|
|
||||||
func actionToStr(action firewall.Action) string {
|
|
||||||
if action == firewall.ActionAccept {
|
|
||||||
return "ACCEPT"
|
|
||||||
}
|
|
||||||
return "DROP"
|
|
||||||
}
|
|
||||||
|
|
||||||
func applyPort(flag string, port *firewall.Port) []string {
|
|
||||||
if port == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if port.IsRange && len(port.Values) == 2 {
|
|
||||||
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(port.Values) > 1 {
|
|
||||||
portList := make([]string, len(port.Values))
|
|
||||||
for i, p := range port.Values {
|
|
||||||
portList[i] = strconv.Itoa(int(p))
|
|
||||||
}
|
|
||||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
|
||||||
}
|
|
||||||
|
|
||||||
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
|
||||||
}
|
|
||||||
@@ -1,93 +0,0 @@
|
|||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// InterfaceAllower opens the NetBird interface on the iptables filter INPUT
|
|
||||||
// chain so the host firewall doesn't drop traffic the userspace firewall
|
|
||||||
// handles. It is the fallback used when nftables is unavailable (an
|
|
||||||
// iptables-legacy host).
|
|
||||||
//
|
|
||||||
// It opens INPUT only: the userspace router never forwards in the kernel.
|
|
||||||
// firewalld trust is handled by the uspfilter manager, not here.
|
|
||||||
type InterfaceAllower struct {
|
|
||||||
ifaceName string
|
|
||||||
ipt4 *iptables.IPTables
|
|
||||||
// ipt6 is nil when the interface has no IPv6 overlay address.
|
|
||||||
ipt6 *iptables.IPTables
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewInterfaceAllower builds an iptables allower for the interface. It returns
|
|
||||||
// an error when iptables is unavailable, so the caller can fall back to
|
|
||||||
// firewalld trust.
|
|
||||||
func NewInterfaceAllower(wgIface iFaceMapper) (*InterfaceAllower, error) {
|
|
||||||
ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("iptables not available: %w", err)
|
|
||||||
}
|
|
||||||
if _, err := ipt4.ListChains(tableFilter); err != nil {
|
|
||||||
return nil, fmt.Errorf("iptables filter table not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
a := &InterfaceAllower{ifaceName: wgIface.Name(), ipt4: ipt4}
|
|
||||||
|
|
||||||
// Missing v6 must not break the v4 path: open v4 only and continue.
|
|
||||||
if wgIface.Address().HasIPv6() {
|
|
||||||
ipt6, err := iptables.NewWithProtocol(iptables.ProtocolIPv6)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("ip6tables not available, opening interface on v4 only: %v", err)
|
|
||||||
} else if _, err := ipt6.ListChains(tableFilter); err != nil {
|
|
||||||
log.Warnf("ip6tables filter table not available, opening interface on v4 only: %v", err)
|
|
||||||
} else {
|
|
||||||
a.ipt6 = ipt6
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply inserts the interface accept rule on the filter INPUT chain. It removes
|
|
||||||
// any stale rule first so an unclean exit (e.g. SIGKILL, where Close never ran)
|
|
||||||
// is recovered deterministically rather than accumulating duplicates.
|
|
||||||
func (a *InterfaceAllower) Apply() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, ipt := range a.clients() {
|
|
||||||
if err := ipt.DeleteIfExists(tableFilter, chainInput, a.inputRule()...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("clean stale interface accept rule: %w", err))
|
|
||||||
}
|
|
||||||
if err := ipt.Insert(tableFilter, chainInput, 1, a.inputRule()...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add interface accept rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close removes the interface accept rule.
|
|
||||||
func (a *InterfaceAllower) Close() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, ipt := range a.clients() {
|
|
||||||
if err := ipt.DeleteIfExists(tableFilter, chainInput, a.inputRule()...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove interface accept rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *InterfaceAllower) inputRule() []string {
|
|
||||||
return []string{"-i", a.ifaceName, "-j", "ACCEPT"}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *InterfaceAllower) clients() []*iptables.IPTables {
|
|
||||||
clients := []*iptables.IPTables{a.ipt4}
|
|
||||||
if a.ipt6 != nil {
|
|
||||||
clients = append(clients, a.ipt6)
|
|
||||||
}
|
|
||||||
return clients
|
|
||||||
}
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"github.com/lrh3321/ipset-go"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) createIpSet(setName string, sources []netip.Prefix) error {
|
|
||||||
if err := r.createIPSet(setName); err != nil {
|
|
||||||
return fmt.Errorf("create set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range sources {
|
|
||||||
if err := r.addPrefixToIPSet(setName, prefix); err != nil {
|
|
||||||
// The refcounter records nothing when this callback errors,
|
|
||||||
// so destroy the set or it leaks in the kernel. A partial
|
|
||||||
// source set would also fail-open for deny rules, so the
|
|
||||||
// rule must fail rather than install with a missing source.
|
|
||||||
if derr := r.destroyIPSet(setName); derr != nil {
|
|
||||||
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add element to set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) deleteIpSet(setName string) error {
|
|
||||||
if err := r.destroyIPSet(setName); err != nil {
|
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("deleted unused ipset %s", setName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|
||||||
name := r.ipsetName(set.HashedName())
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
if err := r.addPrefixToIPSet(name, prefix); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add prefix to ipset: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if merr == nil {
|
|
||||||
log.Debugf("updated set %s with prefixes %v", name, prefixes)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) ipsetName(name string) string {
|
|
||||||
if r.v6 {
|
|
||||||
return name + "-v6"
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) createIPSet(name string) error {
|
|
||||||
opts := ipset.CreateOptions{
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
if r.v6 {
|
|
||||||
opts.Family = ipset.FamilyIPV6
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Create(name, ipset.TypeHashNet, opts); err != nil {
|
|
||||||
return fmt.Errorf("create ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("created ipset %s with type hash:net", name)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addPrefixToIPSet(name string, prefix netip.Prefix) error {
|
|
||||||
addr := prefix.Addr()
|
|
||||||
ip := addr.AsSlice()
|
|
||||||
|
|
||||||
entry := &ipset.Entry{
|
|
||||||
IP: ip,
|
|
||||||
CIDR: uint8(prefix.Bits()),
|
|
||||||
Replace: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := ipset.Add(name, entry); err != nil {
|
|
||||||
return fmt.Errorf("add prefix to ipset %s: %w", name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) destroyIPSet(name string) error {
|
|
||||||
return ipset.Destroy(name)
|
|
||||||
}
|
|
||||||
@@ -3,6 +3,7 @@ package iptables
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -17,21 +18,25 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager of iptables firewall. Per-family state (peer ACLs, route
|
type resetter interface {
|
||||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
Reset() error
|
||||||
// by family and provides the public firewall.Manager surface.
|
}
|
||||||
|
|
||||||
|
// Manager of iptables firewall
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
family4 *family
|
aclMgr *aclManager
|
||||||
|
router *router
|
||||||
rawSupported bool
|
rawSupported bool
|
||||||
|
|
||||||
// IPv6 counterparts, nil when no v6 overlay
|
// IPv6 counterparts, nil when no v6 overlay
|
||||||
ipv6Client *iptables.IPTables
|
ipv6Client *iptables.IPTables
|
||||||
family6 *family
|
aclMgr6 *aclManager
|
||||||
|
router6 *router
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -52,9 +57,14 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
ipv4Client: iptablesClient,
|
ipv4Client: iptablesClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.family4, err = newFamily(iptablesClient, wgIface, mtu)
|
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create family: %w", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
if wgIface.Address().HasIPv6() {
|
||||||
@@ -71,18 +81,21 @@ func (m *Manager) createIPv6Components(wgIface iFaceMapper, mtu uint16) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init ip6tables: %w", err)
|
return fmt.Errorf("init ip6tables: %w", err)
|
||||||
}
|
}
|
||||||
|
m.ipv6Client = ip6Client
|
||||||
|
|
||||||
family6, err := newFamily(ip6Client, wgIface, mtu)
|
m.router6, err = newRouter(ip6Client, wgIface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create v6 family: %w", err)
|
return fmt.Errorf("create v6 router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 family, since
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
family6.ipFwdState = m.family4.ipFwdState
|
m.router6.ipFwdState = m.router.ipFwdState
|
||||||
|
|
||||||
m.ipv6Client = ip6Client
|
m.aclMgr6, err = newAclManager(ip6Client, wgIface)
|
||||||
m.family6 = family6
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -96,7 +109,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
MTU: m.family4.mtu,
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
stateManager.RegisterState(state)
|
stateManager.RegisterState(state)
|
||||||
@@ -128,24 +141,31 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initChains initializes the per-family firewall state for both
|
// initChains initializes router and ACL chains for both address families,
|
||||||
// address families, rolling back on failure.
|
// rolling back on failure.
|
||||||
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
||||||
type initStep struct {
|
type initStep struct {
|
||||||
name string
|
name string
|
||||||
r *family
|
init func(*statemanager.Manager) error
|
||||||
|
mgr resetter
|
||||||
}
|
}
|
||||||
|
|
||||||
steps := []initStep{{"v4", m.family4}}
|
steps := []initStep{
|
||||||
|
{"router", m.router.init, m.router},
|
||||||
|
{"acl manager", m.aclMgr.init, m.aclMgr},
|
||||||
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
steps = append(steps, initStep{"v6", m.family6})
|
steps = append(steps,
|
||||||
|
initStep{"v6 router", m.router6.init, m.router6},
|
||||||
|
initStep{"v6 acl manager", m.aclMgr6.init, m.aclMgr6},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
var initialized []initStep
|
var initialized []initStep
|
||||||
for _, s := range steps {
|
for _, s := range steps {
|
||||||
if err := s.r.init(stateManager); err != nil {
|
if err := s.init(stateManager); err != nil {
|
||||||
for i := len(initialized) - 1; i >= 0; i-- {
|
for i := len(initialized) - 1; i >= 0; i-- {
|
||||||
if rerr := initialized[i].r.Reset(); rerr != nil {
|
if rerr := initialized[i].mgr.Reset(); rerr != nil {
|
||||||
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
log.Warnf("rollback %s: %v", initialized[i].name, rerr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -156,50 +176,84 @@ func (m *Manager) initChains(stateManager *statemanager.Manager) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFilterRule installs a packet-filtering rule. See firewall.Manager
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
// docs for destination semantics. Sources are a single address family;
|
//
|
||||||
// the rule is dispatched to the matching v4 / v6 backend.
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddFilterRule(
|
func (m *Manager) AddPeerFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
ip net.IP,
|
||||||
destination firewall.Network,
|
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
ipsetName string,
|
||||||
if len(sources) == 0 {
|
) ([]firewall.Rule, error) {
|
||||||
return nil, firewall.ErrNoSources
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
fam := m.family4
|
if ip.To4() != nil {
|
||||||
if isIPv6Rule(sources, destination) {
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
fam = m.family6
|
|
||||||
}
|
}
|
||||||
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||||
|
}
|
||||||
|
return m.aclMgr6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteFilterRule removes a rule previously added via AddFilterRule.
|
func (m *Manager) AddRouteFiltering(
|
||||||
// The rule is looked up by id in each family's filter cache.
|
id []byte,
|
||||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
sources []netip.Prefix,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
id := rule.ID()
|
if isIPv6RouteRule(sources, destination) {
|
||||||
if m.family4.hasRule(id) {
|
if !m.hasIPv6() {
|
||||||
return m.family4.DeleteFilterRule(rule)
|
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||||
|
}
|
||||||
|
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
if m.hasIPv6() && m.family6.hasRule(id) {
|
|
||||||
return m.family6.DeleteFilterRule(rule)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
return destination.Prefix.Addr().Is6()
|
||||||
}
|
}
|
||||||
log.Debugf("filter rule %s not found in any family", id)
|
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||||
return nil
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && isIPv6IptRule(rule) {
|
||||||
|
return m.aclMgr6.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
|
return m.aclMgr.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPv6IptRule(rule firewall.Rule) bool {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
return ok && r.v6
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule deletes a routing rule.
|
||||||
|
// Route rules are keyed by content hash. Check v4 first, try v6 if not found.
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && !m.router.hasRule(rule.ID()) {
|
||||||
|
return m.router6.DeleteRouteRule(rule)
|
||||||
|
}
|
||||||
|
return m.router.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
@@ -218,10 +272,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddNatRule(pair)
|
return m.router6.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.family4.AddNatRule(pair); err != nil {
|
if err := m.router.AddNatRule(pair); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,7 +284,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
// wildcard 0.0.0.0/0 destination where the client resolves DNS.
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -246,18 +300,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.family6.RemoveNatRule(pair)
|
return m.router6.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -266,14 +320,11 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
m.mutex.Lock()
|
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -290,13 +341,19 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.family6.Reset(); err != nil {
|
if err := m.aclMgr6.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 acl manager: %w", err))
|
||||||
|
}
|
||||||
|
if err := m.router6.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.family4.Reset(); err != nil {
|
if err := m.aclMgr.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
|
||||||
|
}
|
||||||
|
if err := m.router.Reset(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
// Appending to merr intentionally blocks DeleteState below so ShutdownState
|
||||||
@@ -315,6 +372,27 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic.
|
||||||
|
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||||
|
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if _, err := m.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("allow netbird v4 interface traffic: %w", err))
|
||||||
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if _, err := m.AddPeerFiltering(nil, net.IPv6zero, firewall.ProtocolALL, nil, nil, firewall.ActionAccept, ""); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("allow netbird v6 interface traffic: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
@@ -324,14 +402,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -346,9 +424,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddDNATRule(rule)
|
return m.router6.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
return m.family4.AddDNATRule(rule)
|
return m.router.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
@@ -356,10 +434,10 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.hasIPv6() && !m.family4.hasDNATRule(rule.ID()) {
|
if m.hasIPv6() && !m.router.hasRule(rule.ID()+dnatSuffix) {
|
||||||
return m.family6.DeleteDNATRule(rule)
|
return m.router6.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
return m.family4.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSet updates the set with the given prefixes
|
// UpdateSet updates the set with the given prefixes
|
||||||
@@ -376,12 +454,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||||
return fmt.Errorf("update v6 set: %w", err)
|
return fmt.Errorf("update v6 set: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -398,9 +476,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
@@ -412,9 +490,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
@@ -426,9 +504,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
@@ -440,14 +518,14 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRaw = "NETBIRD-RAW"
|
chainNameRaw = "NETBIRD-RAW"
|
||||||
chainOutput = "OUTPUT"
|
chainOUTPUT = "OUTPUT"
|
||||||
tableRaw = "raw"
|
tableRaw = "raw"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -522,15 +600,15 @@ func (m *Manager) initNoTrackChain() error {
|
|||||||
|
|
||||||
jumpRule := []string{"-j", chainNameRaw}
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainOutput, 1, jumpRule...); err != nil {
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil {
|
||||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
log.Debugf("delete orphan chain: %v", delErr)
|
log.Debugf("delete orphan chain: %v", delErr)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("add output jump rule: %w", err)
|
return fmt.Errorf("add output jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.ipv4Client.InsertUnique(tableRaw, chainPrerouting, 1, jumpRule...); err != nil {
|
if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil {
|
||||||
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); delErr != nil {
|
if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil {
|
||||||
log.Debugf("delete output jump rule: %v", delErr)
|
log.Debugf("delete output jump rule: %v", delErr)
|
||||||
}
|
}
|
||||||
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil {
|
||||||
@@ -557,11 +635,11 @@ func (m *Manager) cleanupNoTrackChain() error {
|
|||||||
|
|
||||||
jumpRule := []string{"-j", chainNameRaw}
|
jumpRule := []string{"-j", chainNameRaw}
|
||||||
|
|
||||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOutput, jumpRule...); err != nil {
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil {
|
||||||
return fmt.Errorf("remove output jump rule: %w", err)
|
return fmt.Errorf("remove output jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPrerouting, jumpRule...); err != nil {
|
if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil {
|
||||||
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
return fmt.Errorf("remove prerouting jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -576,13 +654,3 @@ func (m *Manager) cleanupNoTrackChain() error {
|
|||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isIPv6Rule reports whether the rule belongs to the IPv6 family, from
|
|
||||||
// the destination prefix when set, otherwise from the (single-family)
|
|
||||||
// sources.
|
|
||||||
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
|
||||||
if destination.IsPrefix() {
|
|
||||||
return destination.Prefix.Addr().Is6()
|
|
||||||
}
|
|
||||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build integration && !android
|
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -67,39 +65,46 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule2 fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
rr := rule2.(*Rule)
|
for _, r := range rule2 {
|
||||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
rr := r.(*Rule)
|
||||||
|
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
for _, r := range rule2 {
|
||||||
|
err := manager.DeletePeerRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "udp", nil, port, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainACLInput)
|
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
require.NoError(t, err, "failed check chain exists")
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainACLInput)
|
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -121,13 +126,15 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
|||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{22}}
|
port := &fw.Port{Values: []uint16{22}}
|
||||||
|
|
||||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-ssh")
|
||||||
require.NoError(t, err, "failed to add deny rule")
|
require.NoError(t, err, "failed to add deny rule")
|
||||||
require.NotNil(t, rule, "deny rule should not be nil")
|
require.NotEmpty(t, rule, "deny rule should not be empty")
|
||||||
|
|
||||||
// Verify the rule was added by checking iptables
|
// Verify the rule was added by checking iptables
|
||||||
rr := rule.(*Rule)
|
for _, r := range rule {
|
||||||
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
rr := r.(*Rule)
|
||||||
|
checkRuleSpecs(t, ipv4Client, rr.chain, true, rr.specs...)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("deny rule precedence test", func(t *testing.T) {
|
t.Run("deny rule precedence test", func(t *testing.T) {
|
||||||
@@ -135,40 +142,36 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
|||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
|
||||||
// Add accept rule first
|
// Add accept rule first
|
||||||
_, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
_, err := manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "accept-http")
|
||||||
require.NoError(t, err, "failed to add accept rule")
|
require.NoError(t, err, "failed to add accept rule")
|
||||||
|
|
||||||
// Add deny rule second for same IP/port - this should take precedence
|
// Add deny rule second for same IP/port - this should take precedence
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionDrop)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionDrop, "deny-http")
|
||||||
require.NoError(t, err, "failed to add deny rule")
|
require.NoError(t, err, "failed to add deny rule")
|
||||||
|
|
||||||
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
// Inspect the actual iptables rules to verify deny rule comes before accept rule
|
||||||
rules, err := ipv4Client.List("filter", chainACLInput)
|
rules, err := ipv4Client.List("filter", chainNameInputRules)
|
||||||
require.NoError(t, err, "failed to list iptables rules")
|
require.NoError(t, err, "failed to list iptables rules")
|
||||||
|
|
||||||
// Debug: print all rules
|
// Debug: print all rules
|
||||||
t.Logf("All iptables rules in chain %s:", chainACLInput)
|
t.Logf("All iptables rules in chain %s:", chainNameInputRules)
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
t.Logf(" [%d] %s", i, rule)
|
t.Logf(" [%d] %s", i, rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single-source rules emit a direct `-s <ip>/32 ... --dport 80`
|
|
||||||
// match. Match on that shape instead of the legacy
|
|
||||||
// per-(action,port) ipset names ("deny-http"/"accept-http")
|
|
||||||
// that this test predates.
|
|
||||||
srcMatch := fmt.Sprintf("-s %s/32", ip)
|
|
||||||
var denyRuleIndex, acceptRuleIndex = -1, -1
|
var denyRuleIndex, acceptRuleIndex = -1, -1
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
if !strings.Contains(rule, srcMatch) || !strings.Contains(rule, "--dport 80") {
|
if strings.Contains(rule, "DROP") {
|
||||||
continue
|
|
||||||
}
|
|
||||||
if strings.Contains(rule, "-j DROP") {
|
|
||||||
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
t.Logf("Found DROP rule at index %d: %s", i, rule)
|
||||||
denyRuleIndex = i
|
if strings.Contains(rule, "deny-http") && strings.Contains(rule, "80") {
|
||||||
|
denyRuleIndex = i
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if strings.Contains(rule, "-j ACCEPT") {
|
if strings.Contains(rule, "ACCEPT") {
|
||||||
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
t.Logf("Found ACCEPT rule at index %d: %s", i, rule)
|
||||||
acceptRuleIndex = i
|
if strings.Contains(rule, "accept-http") && strings.Contains(rule, "80") {
|
||||||
|
acceptRuleIndex = i
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,6 +196,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// just check on the local interface
|
||||||
manager, err := Create(mock, iface.DefaultMTU)
|
manager, err := Create(mock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
@@ -206,39 +210,27 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule2 fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("single source uses direct -s match (no ipset)", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := netip.MustParseAddr("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", port, nil, fw.ActionAccept)
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
require.NoError(t, err, "failed to add rule")
|
for _, r := range rule2 {
|
||||||
require.NotNil(t, rule2)
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Contains(t, rule2.(*Rule).specs, "-s",
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
"single-source rule should use direct -s match, not an ipset")
|
require.Equal(t, r.(*Rule).ip, "10.20.0.3", "ipset IP must be set")
|
||||||
require.Empty(t, findSets(rule2.(*Rule).specs),
|
|
||||||
"single-source rule should not allocate a shared ipset")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete single-source rule", func(t *testing.T) {
|
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule2), "failed to delete rule")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("multi-source uses shared ipset", func(t *testing.T) {
|
|
||||||
sources := []netip.Prefix{
|
|
||||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.3"), 32),
|
|
||||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.4"), 32),
|
|
||||||
netip.PrefixFrom(netip.MustParseAddr("10.20.0.5"), 32),
|
|
||||||
}
|
}
|
||||||
port := &fw.Port{Values: []uint16{8080}}
|
})
|
||||||
multi, err := manager.AddFilterRule(nil, sources, fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
|
||||||
require.NoError(t, err, "failed to add multi-source rule")
|
|
||||||
require.NotNil(t, multi, "multi-source rule must produce one iptables rule")
|
|
||||||
sets := findSets(multi.(*Rule).specs)
|
|
||||||
require.Len(t, sets, 1, "multi-source rule must reference exactly one ipset")
|
|
||||||
|
|
||||||
require.NoError(t, manager.DeleteFilterRule(multi))
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
|
for _, r := range rule2 {
|
||||||
|
err := manager.DeletePeerRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
|
||||||
|
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
@@ -289,7 +281,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
1147
client/firewall/iptables/router_linux.go
Normal file
1147
client/firewall/iptables/router_linux.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
//go:build integration && !android
|
//go:build !android
|
||||||
|
|
||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
@@ -52,12 +52,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
// 11. MSS clamping rule for outbound traffic
|
// 11. MSS clamping rule for outbound traffic
|
||||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPostrouting, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPostrouting)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
require.True(t, exists, "postrouting jump rule should exist")
|
require.True(t, exists, "postrouting jump rule should exist")
|
||||||
|
|
||||||
exists, err = manager.iptablesClient.Exists(tableMangle, chainPrerouting, "-j", chainRTPre)
|
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPrerouting)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
|
||||||
require.True(t, exists, "prerouting jump rule should exist")
|
require.True(t, exists, "prerouting jump rule should exist")
|
||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
@@ -84,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
@@ -95,7 +95,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
err = manager.AddNatRule(testCase.InputPair)
|
err = manager.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "marking rule should be inserted")
|
require.NoError(t, err, "marking rule should be inserted")
|
||||||
|
|
||||||
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
markingRule := []string{
|
markingRule := []string{
|
||||||
"-i", ifaceMock.Name(),
|
"-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -106,8 +106,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "marking rule should be created")
|
require.True(t, exists, "marking rule should be created")
|
||||||
foundRule, found := manager.rules[natRuleKey]
|
foundRule, found := manager.rules[natRuleKey]
|
||||||
@@ -121,7 +121,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Check inverse rule
|
// Check inverse rule
|
||||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||||
inverseMarkingRule := []string{
|
inverseMarkingRule := []string{
|
||||||
"!", "-i", ifaceMock.Name(),
|
"!", "-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -132,8 +132,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
if testCase.InputPair.Masquerade {
|
if testCase.InputPair.Masquerade {
|
||||||
require.True(t, exists, "inverse marking rule should be created")
|
require.True(t, exists, "inverse marking rule should be created")
|
||||||
foundRule, found := manager.rules[inverseRuleKey]
|
foundRule, found := manager.rules[inverseRuleKey]
|
||||||
@@ -157,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
manager, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -170,7 +170,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
err = manager.RemoveNatRule(testCase.InputPair)
|
err = manager.RemoveNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
|
|
||||||
natRuleKey := testCase.InputPair.GenKey(firewall.NatFormat)
|
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
|
||||||
markingRule := []string{
|
markingRule := []string{
|
||||||
"-i", ifaceMock.Name(),
|
"-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -181,8 +181,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableMangle, chainRTPre, markingRule...)
|
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
require.False(t, exists, "marking rule should not exist")
|
require.False(t, exists, "marking rule should not exist")
|
||||||
|
|
||||||
_, found := manager.rules[natRuleKey]
|
_, found := manager.rules[natRuleKey]
|
||||||
@@ -190,7 +190,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Check inverse rule removal
|
// Check inverse rule removal
|
||||||
inversePair := firewall.GetInversePair(testCase.InputPair)
|
inversePair := firewall.GetInversePair(testCase.InputPair)
|
||||||
inverseRuleKey := inversePair.GenKey(firewall.NatFormat)
|
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
|
||||||
inverseMarkingRule := []string{
|
inverseMarkingRule := []string{
|
||||||
"!", "-i", ifaceMock.Name(),
|
"!", "-i", ifaceMock.Name(),
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
@@ -201,8 +201,8 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
||||||
}
|
}
|
||||||
|
|
||||||
exists, err = iptablesClient.Exists(tableMangle, chainRTPre, inverseMarkingRule...)
|
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPre)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
|
||||||
require.False(t, exists, "inverse marking rule should not exist")
|
require.False(t, exists, "inverse marking rule should not exist")
|
||||||
|
|
||||||
_, found = manager.rules[inverseRuleKey]
|
_, found = manager.rules[inverseRuleKey]
|
||||||
@@ -219,13 +219,13 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "Failed to create iptables client")
|
require.NoError(t, err, "Failed to create iptables client")
|
||||||
|
|
||||||
r, err := newFamily(iptablesClient, ifaceMock, iface.DefaultMTU)
|
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create family manager")
|
require.NoError(t, err, "Failed to create router manager")
|
||||||
require.NoError(t, r.init(nil))
|
require.NoError(t, r.init(nil))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := r.Reset()
|
err := r.Reset()
|
||||||
require.NoError(t, err, "Failed to reset family")
|
require.NoError(t, err, "Failed to reset router")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -334,30 +334,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddFilterRule failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
stored, ok := r.filters[ruleKey.ID()]
|
// Check if the rule is in the internal map
|
||||||
require.True(t, ok, "rule not stored in filters")
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
t.Logf("Internal rule: %v", stored.specs)
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFwdIn, stored.specs...)
|
// Log the internal rule
|
||||||
|
t.Logf("Internal rule: %v", rule)
|
||||||
|
|
||||||
|
// Check if the rule exists in iptables
|
||||||
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
if tt.expectSet {
|
var source firewall.Network
|
||||||
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
if len(tt.sources) > 1 {
|
||||||
_, exists := r.ipsetCounter.Get(setName)
|
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||||
assert.True(t, exists, "IPSet not created")
|
} else if len(tt.sources) > 0 {
|
||||||
assert.NotEmpty(t, findSets(stored.specs), "Rule should reference an ipset")
|
source.Prefix = tt.sources[0]
|
||||||
|
}
|
||||||
|
// Verify rule content
|
||||||
|
params := routeFilteringRuleParams{
|
||||||
|
Source: source,
|
||||||
|
Destination: firewall.Network{Prefix: tt.destination},
|
||||||
|
Proto: tt.proto,
|
||||||
|
SPort: tt.sPort,
|
||||||
|
DPort: tt.dPort,
|
||||||
|
Action: tt.action,
|
||||||
}
|
}
|
||||||
|
|
||||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||||
|
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||||
|
|
||||||
|
if tt.expectSet {
|
||||||
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
|
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||||
|
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||||
|
|
||||||
|
// Check if the set was created
|
||||||
|
_, exists := r.ipsetCounter.Get(setName)
|
||||||
|
assert.True(t, exists, "IPSet not created")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
|
||||||
|
|
||||||
|
// Clean up
|
||||||
|
err = r.DeleteRouteRule(ruleKey)
|
||||||
|
require.NoError(t, err, "Failed to delete rule")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindSetNameInRule(t *testing.T) {
|
func TestFindSetNameInRule(t *testing.T) {
|
||||||
|
r := &router{}
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
rule []string
|
rule []string
|
||||||
@@ -398,7 +430,7 @@ func TestFindSetNameInRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
result := findSets(tc.rule)
|
result := r.findSets(tc.rule)
|
||||||
|
|
||||||
if len(result) != len(tc.expected) {
|
if len(result) != len(tc.expected) {
|
||||||
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||||
|
|||||||
@@ -1,263 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
|
||||||
if r.legacyManagement {
|
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
|
||||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
|
||||||
if err := r.addNatRule(pair); err != nil {
|
|
||||||
return fmt.Errorf("add nat rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
|
||||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
|
||||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
|
||||||
if pair.Masquerade {
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
|
||||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
|
||||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|
||||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", "ACCEPT"}
|
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFwdIn, rule...); err != nil {
|
|
||||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleID] = rule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|
||||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLegacyManagement returns the current legacy management mode
|
|
||||||
func (r *family) GetLegacyManagement() bool {
|
|
||||||
return r.legacyManagement
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
|
||||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
|
||||||
r.legacyManagement = isLegacy
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
|
||||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for k, rule := range r.rules {
|
|
||||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFwdIn, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %w", err))
|
|
||||||
} else {
|
|
||||||
delete(r.rules, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.updateState()
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addPostroutingRules() error {
|
|
||||||
// First rule for outbound masquerade
|
|
||||||
rule1 := []string{
|
|
||||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
|
|
||||||
"!", "-o", "lo",
|
|
||||||
"-j", "MASQUERADE",
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
|
|
||||||
return fmt.Errorf("add outbound masquerade rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules["static-nat-outbound"] = rule1
|
|
||||||
|
|
||||||
// Second rule for return traffic masquerade
|
|
||||||
rule2 := []string{
|
|
||||||
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
|
|
||||||
"-o", r.wgIface.Name(),
|
|
||||||
"-j", "MASQUERADE",
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
|
|
||||||
return fmt.Errorf("add return masquerade rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules["static-nat-return"] = rule2
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
|
||||||
func (r *family) addMSSClampingRules() error {
|
|
||||||
overhead := uint16(ipv4TCPHeaderSize)
|
|
||||||
if r.v6 {
|
|
||||||
overhead = ipv6TCPHeaderSize
|
|
||||||
}
|
|
||||||
mss := r.mtu - overhead
|
|
||||||
|
|
||||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
|
||||||
jumpRule := jumpRuleSpec(chainRTMSSClamp)
|
|
||||||
if err := r.iptablesClient.Insert(tableMangle, chainForward, 1, jumpRule...); err != nil {
|
|
||||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
|
||||||
}
|
|
||||||
r.rules[jumpMSSClamp] = jumpRule
|
|
||||||
|
|
||||||
ruleOut := []string{
|
|
||||||
"-o", r.wgIface.Name(),
|
|
||||||
"-p", "tcp",
|
|
||||||
"--tcp-flags", "SYN,RST", "SYN",
|
|
||||||
"-j", "TCPMSS",
|
|
||||||
"--set-mss", fmt.Sprintf("%d", mss),
|
|
||||||
}
|
|
||||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSClamp, ruleOut...); err != nil {
|
|
||||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
|
||||||
}
|
|
||||||
r.rules["mss-clamp-out"] = ruleOut
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) insertEstablishedRule(chain string) error {
|
|
||||||
establishedRule := getConntrackEstablished()
|
|
||||||
|
|
||||||
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("insert established rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleID := firewall.RuleID("established-" + chain)
|
|
||||||
r.rules[ruleID] = establishedRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
|
||||||
ruleID := pair.GenKey(firewall.NatFormat)
|
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
|
||||||
return fmt.Errorf("remove existing marking rule for %s: %w", pair.Destination, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
|
||||||
|
|
||||||
markValue := nbnet.PreroutingFwmarkMasquerade
|
|
||||||
if pair.Inverse {
|
|
||||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := []string{"-i", r.wgIface.Name()}
|
|
||||||
if pair.Inverse {
|
|
||||||
rule = []string{"!", "-i", r.wgIface.Name()}
|
|
||||||
}
|
|
||||||
|
|
||||||
rule = append(rule,
|
|
||||||
"-m", "conntrack",
|
|
||||||
"--ctstate", "NEW",
|
|
||||||
)
|
|
||||||
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("apply network -s: %w", err)
|
|
||||||
}
|
|
||||||
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("apply network -d: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rule = append(rule, sourceExp...)
|
|
||||||
rule = append(rule, destExp...)
|
|
||||||
rule = append(rule,
|
|
||||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
|
||||||
)
|
|
||||||
|
|
||||||
// Ensure nat rules come first, so the mark can be overwritten.
|
|
||||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
|
||||||
if err := r.iptablesClient.Insert(tableMangle, chainRTPre, 1, rule...); err != nil {
|
|
||||||
r.dropSourceMatch(rule)
|
|
||||||
return fmt.Errorf("add marking rule for %s: %w", pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleID] = rule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
|
||||||
ruleID := pair.GenKey(firewall.NatFormat)
|
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPre, rule...); err != nil {
|
|
||||||
return fmt.Errorf("remove marking rule for %s: %w", pair.Destination, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement ipset counter: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf("marking rule %s not found", ruleID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,20 +1,18 @@
|
|||||||
package iptables
|
package iptables
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/firewall/manager"
|
// Rule to handle management of rules
|
||||||
|
|
||||||
// Rule to handle management of rules. Source set membership (when the
|
|
||||||
// rule was built against a shared hash:net ipset) is encoded in specs;
|
|
||||||
// DeleteFilterRule recovers it via findSets so the refcounter can drop
|
|
||||||
// the right reference.
|
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
id manager.RuleID
|
ruleID string
|
||||||
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
mangleSpecs []string
|
mangleSpecs []string
|
||||||
|
ip string
|
||||||
chain string
|
chain string
|
||||||
v6 bool
|
v6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) ID() manager.RuleID {
|
func (r *Rule) ID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
103
client/firewall/iptables/rulestore_linux.go
Normal file
103
client/firewall/iptables/rulestore_linux.go
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
package iptables
|
||||||
|
|
||||||
|
import "encoding/json"
|
||||||
|
|
||||||
|
type ipList struct {
|
||||||
|
ips map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIpList(ip string) *ipList {
|
||||||
|
ips := make(map[string]struct{})
|
||||||
|
ips[ip] = struct{}{}
|
||||||
|
|
||||||
|
return &ipList{
|
||||||
|
ips: ips,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipList) addIP(ip string) {
|
||||||
|
s.ips[ip] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (s *ipList) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
IPs map[string]struct{} `json:"ips"`
|
||||||
|
}{
|
||||||
|
IPs: s.ips,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (s *ipList) UnmarshalJSON(data []byte) error {
|
||||||
|
temp := struct {
|
||||||
|
IPs map[string]struct{} `json:"ips"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ips = temp.IPs
|
||||||
|
|
||||||
|
if temp.IPs == nil {
|
||||||
|
temp.IPs = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type ipsetStore struct {
|
||||||
|
ipsets map[string]*ipList
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIpsetStore() *ipsetStore {
|
||||||
|
return &ipsetStore{
|
||||||
|
ipsets: make(map[string]*ipList),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
|
||||||
|
r, ok := s.ipsets[ipsetName]
|
||||||
|
return r, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
|
||||||
|
s.ipsets[ipsetName] = list
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||||
|
delete(s.ipsets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) ipsetNames() []string {
|
||||||
|
names := make([]string, 0, len(s.ipsets))
|
||||||
|
for name := range s.ipsets {
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(struct {
|
||||||
|
IPSets map[string]*ipList `json:"ipsets"`
|
||||||
|
}{
|
||||||
|
IPSets: s.ipsets,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
||||||
|
temp := struct {
|
||||||
|
IPSets map[string]*ipList `json:"ipsets"`
|
||||||
|
}{}
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.ipsets = temp.IPSets
|
||||||
|
|
||||||
|
if temp.IPSets == nil {
|
||||||
|
temp.IPSets = make(map[string]*ipList)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -29,13 +29,17 @@ type ShutdownState struct {
|
|||||||
|
|
||||||
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
|
||||||
|
|
||||||
RouteRules routeRules `json:"route_rules,omitempty"`
|
RouteRules routeRules `json:"route_rules,omitempty"`
|
||||||
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
||||||
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
|
|
||||||
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
|
||||||
|
|
||||||
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
ACLEntries aclEntries `json:"acl_entries,omitempty"`
|
||||||
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
|
||||||
|
|
||||||
|
// IPv6 counterparts
|
||||||
|
RouteRules6 routeRules `json:"route_rules_v6,omitempty"`
|
||||||
|
RouteIPsetCounter6 *ipsetCounter `json:"route_ipset_counter_v6,omitempty"`
|
||||||
|
ACLEntries6 aclEntries `json:"acl_entries_v6,omitempty"`
|
||||||
|
ACLIPsetStore6 *ipsetStore `json:"acl_ipset_store_v6,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -53,14 +57,17 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.RouteRules != nil {
|
if s.RouteRules != nil {
|
||||||
ipt.family4.rules = s.RouteRules
|
ipt.router.rules = s.RouteRules
|
||||||
}
|
}
|
||||||
if s.RouteIPsetCounter != nil {
|
if s.RouteIPsetCounter != nil {
|
||||||
ipt.family4.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.ACLEntries != nil {
|
if s.ACLEntries != nil {
|
||||||
ipt.family4.entries = s.ACLEntries
|
ipt.aclMgr.entries = s.ACLEntries
|
||||||
|
}
|
||||||
|
if s.ACLIPsetStore != nil {
|
||||||
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clean up v6 state even if the current run has no IPv6.
|
// Clean up v6 state even if the current run has no IPv6.
|
||||||
@@ -72,13 +79,16 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
}
|
}
|
||||||
if ipt.hasIPv6() {
|
if ipt.hasIPv6() {
|
||||||
if s.RouteRules6 != nil {
|
if s.RouteRules6 != nil {
|
||||||
ipt.family6.rules = s.RouteRules6
|
ipt.router6.rules = s.RouteRules6
|
||||||
}
|
}
|
||||||
if s.RouteIPsetCounter6 != nil {
|
if s.RouteIPsetCounter6 != nil {
|
||||||
ipt.family6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
ipt.router6.ipsetCounter.LoadData(s.RouteIPsetCounter6)
|
||||||
}
|
}
|
||||||
if s.ACLEntries6 != nil {
|
if s.ACLEntries6 != nil {
|
||||||
ipt.family6.entries = s.ACLEntries6
|
ipt.aclMgr6.entries = s.ACLEntries6
|
||||||
|
}
|
||||||
|
if s.ACLIPsetStore6 != nil {
|
||||||
|
ipt.aclMgr6.ipsetStore = s.ACLIPsetStore6
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
//go:build integration && !android
|
|
||||||
|
|
||||||
package iptables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
func pfx(ip net.IP) []netip.Prefix {
|
|
||||||
if ip == nil {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
if ip.IsUnspecified() {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
a, ok := netip.AddrFromSlice(ip)
|
|
||||||
if !ok {
|
|
||||||
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
|
||||||
}
|
|
||||||
a = a.Unmap()
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
|
||||||
}
|
|
||||||
@@ -3,6 +3,7 @@ package manager
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
@@ -15,12 +16,6 @@ import (
|
|||||||
// method but the IPv6 firewall components were not initialized.
|
// method but the IPv6 firewall components were not initialized.
|
||||||
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
var ErrIPv6NotInitialized = errors.New("IPv6 firewall not initialized")
|
||||||
|
|
||||||
// ErrNoSources is returned when AddFilterRule is called with an empty
|
|
||||||
// source list. "Match any source" must be expressed explicitly with a
|
|
||||||
// /0 prefix; an empty list is a caller error and is rejected rather
|
|
||||||
// than silently widening the rule to every source.
|
|
||||||
var ErrNoSources = errors.New("rule has no sources")
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ForwardingFormatPrefix = "netbird-fwd-"
|
ForwardingFormatPrefix = "netbird-fwd-"
|
||||||
ForwardingFormat = "netbird-fwd-%s-%t"
|
ForwardingFormat = "netbird-fwd-%s-%t"
|
||||||
@@ -28,18 +23,13 @@ const (
|
|||||||
NatFormat = "netbird-nat-%s-%t"
|
NatFormat = "netbird-nat-%s-%t"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RuleID identifies a firewall rule. It is a typed string so the
|
|
||||||
// compiler catches accidental mixing with arbitrary string keys. It is
|
|
||||||
// only an identifier and does not implement Rule.
|
|
||||||
type RuleID string
|
|
||||||
|
|
||||||
// Rule abstraction should be implemented by each firewall manager
|
// Rule abstraction should be implemented by each firewall manager
|
||||||
//
|
//
|
||||||
// Each firewall type for different OS can use different type
|
// Each firewall type for different OS can use different type
|
||||||
// of the properties to hold data of the created rule
|
// of the properties to hold data of the created rule
|
||||||
type Rule interface {
|
type Rule interface {
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
ID() RuleID
|
ID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
// RuleDirection is the traffic direction which a rule is applied
|
||||||
@@ -101,13 +91,6 @@ func (d Network) IsPrefix() bool {
|
|||||||
return d.Prefix.IsValid()
|
return d.Prefix.IsValid()
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsZero returns true if the network designates no destination, i.e. it
|
|
||||||
// is the zero value. A zero Network is the peer-rule sentinel; a non-zero
|
|
||||||
// one carries a prefix or set destination.
|
|
||||||
func (d Network) IsZero() bool {
|
|
||||||
return !d.IsPrefix() && !d.IsSet()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager is the high level abstraction of a firewall manager
|
// Manager is the high level abstraction of a firewall manager
|
||||||
//
|
//
|
||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
@@ -115,42 +98,46 @@ func (d Network) IsZero() bool {
|
|||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init(stateManager *statemanager.Manager) error
|
Init(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// AddFilterRule adds a packet-filtering rule to the firewall.
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
AllowNetbird() error
|
||||||
|
|
||||||
|
// AddPeerFiltering adds a rule to the firewall
|
||||||
//
|
//
|
||||||
// If destination is the zero Network, the rule applies to traffic
|
// If comment argument is empty firewall manager should set
|
||||||
// inbound to this node, i.e. peer ACL semantics, installed in
|
// rule ID as comment for the rule
|
||||||
// the kernel's input chain. If destination is set (prefix or
|
|
||||||
// set), the rule applies to forwarded traffic with that
|
|
||||||
// destination, route ACL semantics, installed in the forward
|
|
||||||
// chain.
|
|
||||||
//
|
//
|
||||||
// sources must be a single address family; the caller splits mixed
|
// Note: Callers should call Flush() after adding rules to ensure
|
||||||
// families and calls once per family. "Match any source" must be
|
// they are applied to the kernel and rule handles are refreshed.
|
||||||
// expressed with an explicit /0 prefix; an empty sources list is
|
AddPeerFiltering(
|
||||||
// rejected with ErrNoSources so a zeroed list can never widen a
|
|
||||||
// rule to every source.
|
|
||||||
//
|
|
||||||
// Note: callers should call Flush() after adding rules.
|
|
||||||
AddFilterRule(
|
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
ip net.IP,
|
||||||
destination Network,
|
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
action Action,
|
action Action,
|
||||||
) (Rule, error)
|
ipsetName string,
|
||||||
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeleteFilterRule removes a filtering rule previously added via
|
// DeletePeerRule from the firewall by rule definition
|
||||||
// AddFilterRule. The rule's own type identifies whether it lives
|
DeletePeerRule(rule Rule) error
|
||||||
// in the peer (input) or route (forward) path.
|
|
||||||
DeleteFilterRule(rule Rule) error
|
|
||||||
|
|
||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
IsStateful() bool
|
IsStateful() bool
|
||||||
|
|
||||||
|
AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination Network,
|
||||||
|
proto Protocol,
|
||||||
|
sPort, dPort *Port,
|
||||||
|
action Action,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
|
// DeleteRouteRule deletes a routing rule
|
||||||
|
DeleteRouteRule(rule Rule) error
|
||||||
|
|
||||||
// AddNatRule inserts a routing NAT rule
|
// AddNatRule inserts a routing NAT rule
|
||||||
AddNatRule(pair RouterPair) error
|
AddNatRule(pair RouterPair) error
|
||||||
|
|
||||||
@@ -198,9 +185,8 @@ type Manager interface {
|
|||||||
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenKey builds the rule id for this pair from the given format.
|
func GenKey(format string, pair RouterPair) string {
|
||||||
func (p RouterPair) GenKey(format string) RuleID {
|
return fmt.Sprintf(format, pair.ID, pair.Inverse)
|
||||||
return RuleID(fmt.Sprintf(format, p.ID, p.Inverse))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LegacyManager defines the interface for legacy management operations
|
// LegacyManager defines the interface for legacy management operations
|
||||||
@@ -256,20 +242,6 @@ func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
|||||||
return merged
|
return merged
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmapPrefix normalizes a v4-mapped v6 prefix (::ffff:a.b.c.d) to its
|
|
||||||
// plain v4 form, shifting the prefix length out of the 96-bit mapped
|
|
||||||
// range. Other prefixes are returned unchanged. Keeping prefixes
|
|
||||||
// unmapped ensures v4 rules match consistently and the match builders
|
|
||||||
// read the correct address length.
|
|
||||||
func UnmapPrefix(p netip.Prefix) netip.Prefix {
|
|
||||||
addr := p.Addr()
|
|
||||||
if !addr.Is4In6() {
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
bits := max(p.Bits()-96, 0)
|
|
||||||
return netip.PrefixFrom(addr.Unmap(), bits)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
// SortPrefixes sorts the given slice of netip.Prefix in place.
|
||||||
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
// It sorts first by IP address, then by prefix length (most specific to least specific).
|
||||||
func SortPrefixes(prefixes []netip.Prefix) {
|
func SortPrefixes(prefixes []netip.Prefix) {
|
||||||
|
|||||||
@@ -13,13 +13,13 @@ type ForwardRule struct {
|
|||||||
TranslatedPort Port
|
TranslatedPort Port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ForwardRule) ID() RuleID {
|
func (r ForwardRule) ID() string {
|
||||||
id := fmt.Sprintf("%s;%s;%s;%s",
|
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||||
r.Protocol,
|
r.Protocol,
|
||||||
r.DestinationPort.String(),
|
r.DestinationPort.String(),
|
||||||
r.TranslatedAddress.String(),
|
r.TranslatedAddress.String(),
|
||||||
r.TranslatedPort.String())
|
r.TranslatedPort.String())
|
||||||
return RuleID(id)
|
return id
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r ForwardRule) String() string {
|
func (r ForwardRule) String() string {
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (h Set) Comment() string {
|
|||||||
|
|
||||||
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||||
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||||
prefixes = slices.Clone(prefixes)
|
// sort for consistent naming
|
||||||
SortPrefixes(prefixes)
|
SortPrefixes(prefixes)
|
||||||
|
|
||||||
hash := sha256.New()
|
hash := sha256.New()
|
||||||
|
|||||||
713
client/firewall/nftables/acl_linux.go
Normal file
713
client/firewall/nftables/acl_linux.go
Normal file
@@ -0,0 +1,713 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/nftables"
|
||||||
|
"github.com/google/nftables/binaryutil"
|
||||||
|
"github.com/google/nftables/expr"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
|
||||||
|
// rules chains contains the effective ACL rules
|
||||||
|
chainNameInputRules = "netbird-acl-input-rules"
|
||||||
|
|
||||||
|
// filter chains contains the rules that jump to the rules chains
|
||||||
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
|
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||||
|
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||||
|
)
|
||||||
|
|
||||||
|
const flushError = "flush: %w"
|
||||||
|
|
||||||
|
type AclManager struct {
|
||||||
|
rConn *nftables.Conn
|
||||||
|
sConn *nftables.Conn
|
||||||
|
wgIface iFaceMapper
|
||||||
|
routingFwChainName string
|
||||||
|
af addrFamily
|
||||||
|
|
||||||
|
workTable *nftables.Table
|
||||||
|
chainInputRules *nftables.Chain
|
||||||
|
chainPrerouting *nftables.Chain
|
||||||
|
|
||||||
|
ipsetStore *ipsetStore
|
||||||
|
rules map[string]*Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
|
||||||
|
// sConn is used for creating sets and adding/removing elements from them
|
||||||
|
// it's differ then rConn (which does create new conn for each flush operation)
|
||||||
|
// and is permanent. Using same connection for both type of operations
|
||||||
|
// overloads netlink with high amount of rules ( > 10000)
|
||||||
|
sConn, err := nftables.New(nftables.AsLasting())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create nf conn: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AclManager{
|
||||||
|
rConn: &nftables.Conn{},
|
||||||
|
sConn: sConn,
|
||||||
|
wgIface: wgIface,
|
||||||
|
workTable: table,
|
||||||
|
routingFwChainName: routingFwChainName,
|
||||||
|
af: familyForAddr(table.Family == nftables.TableFamilyIPv4),
|
||||||
|
|
||||||
|
ipsetStore: newIpsetStore(),
|
||||||
|
rules: make(map[string]*Rule),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) init(workTable *nftables.Table) error {
|
||||||
|
m.workTable = workTable
|
||||||
|
return m.createDefaultChains()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeerFiltering rule to the firewall
|
||||||
|
//
|
||||||
|
// If comment argument is empty firewall manager should set
|
||||||
|
// rule ID as comment for the rule
|
||||||
|
func (m *AclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
|
ip net.IP,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
|
var ipset *nftables.Set
|
||||||
|
if ipsetName != "" {
|
||||||
|
var err error
|
||||||
|
ipset, err = m.addIpToSet(ipsetName, ip)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
newRules = append(newRules, ioRule)
|
||||||
|
return newRules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("invalid rule type")
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.nftSet == nil {
|
||||||
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m.rules, r.ID())
|
||||||
|
return m.rConn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||||
|
if !ok {
|
||||||
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m.rules, r.ID())
|
||||||
|
return m.rConn.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := ips[r.ip.String()]; ok {
|
||||||
|
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: ipToBytes(r.ip, m.af)}})
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("delete elements for set %q: %v", r.nftSet.Name, err)
|
||||||
|
}
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
log.Debugf("flush error of set delete element, %s", r.nftSet.Name)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.ipsetStore.DeleteIpFromSet(r.nftSet.Name, r.ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if after delete, set still contains other IPs,
|
||||||
|
// no need to delete firewall rule and we should exit here
|
||||||
|
if len(ips) > 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
|
}
|
||||||
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.rules, r.ID())
|
||||||
|
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||||
|
|
||||||
|
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// we delete last IP from the set, that means we need to delete
|
||||||
|
// set itself and associated firewall rule too
|
||||||
|
m.rConn.FlushSet(r.nftSet)
|
||||||
|
m.rConn.DelSet(r.nftSet)
|
||||||
|
m.ipsetStore.deleteIpset(r.nftSet.Name)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||||
|
func (m *AclManager) createDefaultAllowRules() error {
|
||||||
|
expIn := []expr.Any{
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = m.rConn.InsertRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: m.chainInputRules,
|
||||||
|
Position: 0,
|
||||||
|
Exprs: expIn,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush rule/chain/set operations from the buffer
|
||||||
|
//
|
||||||
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
|
func (m *AclManager) Flush() error {
|
||||||
|
if err := m.flushWithBackoff(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
|
}
|
||||||
|
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||||
|
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addIOFiltering(
|
||||||
|
ip net.IP,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipset *nftables.Set,
|
||||||
|
) (*Rule, error) {
|
||||||
|
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||||
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
|
return &Rule{
|
||||||
|
nftRule: r.nftRule,
|
||||||
|
mangleRule: r.mangleRule,
|
||||||
|
nftSet: r.nftSet,
|
||||||
|
ruleID: r.ruleID,
|
||||||
|
ip: ip,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var expressions []expr.Any
|
||||||
|
|
||||||
|
if proto != firewall.ProtocolALL {
|
||||||
|
expressions = append(expressions, &expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: m.af.protoOffset,
|
||||||
|
Len: uint32(1),
|
||||||
|
})
|
||||||
|
|
||||||
|
protoData, err := m.af.protoNum(proto)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expressions = append(expressions, &expr.Cmp{
|
||||||
|
Register: 1,
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Data: []byte{protoData},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIP := ipToBytes(ip, m.af)
|
||||||
|
// check if rawIP contains zeroed IPv4 0.0.0.0 value
|
||||||
|
// in that case not add IP match expression into the rule definition
|
||||||
|
if slices.ContainsFunc(rawIP, func(v byte) bool { return v != 0 }) {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: m.af.srcAddrOffset,
|
||||||
|
Len: m.af.addrLen,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
// add individual IP for match if no ipset defined
|
||||||
|
if ipset == nil {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rawIP,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
expressions = append(expressions,
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ipset.Name,
|
||||||
|
SetID: ipset.ID,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expressions = append(expressions, applyPort(sPort, true)...)
|
||||||
|
expressions = append(expressions, applyPort(dPort, false)...)
|
||||||
|
|
||||||
|
mainExpressions := slices.Clone(expressions)
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case firewall.ActionAccept:
|
||||||
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||||
|
case firewall.ActionDrop:
|
||||||
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
|
}
|
||||||
|
|
||||||
|
userData := []byte(ruleId)
|
||||||
|
|
||||||
|
chain := m.chainInputRules
|
||||||
|
rule := &nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: mainExpressions,
|
||||||
|
UserData: userData,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
var nftRule *nftables.Rule
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
nftRule = m.rConn.InsertRule(rule)
|
||||||
|
} else {
|
||||||
|
nftRule = m.rConn.AddRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleStruct := &Rule{
|
||||||
|
nftRule: nftRule,
|
||||||
|
// best effort mangle rule
|
||||||
|
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||||
|
nftSet: ipset,
|
||||||
|
ruleID: ruleId,
|
||||||
|
ip: ip,
|
||||||
|
}
|
||||||
|
m.rules[ruleId] = ruleStruct
|
||||||
|
if ipset != nil {
|
||||||
|
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ruleStruct, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||||
|
if m.chainPrerouting == nil {
|
||||||
|
log.Warn("prerouting chain is not created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingExprs := slices.Clone(expressions)
|
||||||
|
|
||||||
|
// interface
|
||||||
|
preroutingExprs = append([]expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}, preroutingExprs...)
|
||||||
|
|
||||||
|
// local destination and mark
|
||||||
|
preroutingExprs = append(preroutingExprs,
|
||||||
|
&expr.Fib{
|
||||||
|
Register: 1,
|
||||||
|
ResultADDRTYPE: true,
|
||||||
|
FlagDADDR: true,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||||
|
},
|
||||||
|
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: m.chainPrerouting,
|
||||||
|
Exprs: preroutingExprs,
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nfRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createDefaultChains() (err error) {
|
||||||
|
// chainNameInputRules
|
||||||
|
chain := m.createChain(chainNameInputRules)
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
m.chainInputRules = chain
|
||||||
|
|
||||||
|
// netbird-acl-input-filter
|
||||||
|
// type filter hook input priority filter; policy accept;
|
||||||
|
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||||
|
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
||||||
|
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// netbird-acl-forward-filter
|
||||||
|
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||||
|
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||||
|
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
||||||
|
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
|
||||||
|
log.Errorf("failed to allow redirected traffic: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
||||||
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
|
// netbird peer IP.
|
||||||
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
|
// Chain is created by route manager
|
||||||
|
// TODO: move creation to a common place
|
||||||
|
m.chainPrerouting = &nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
|
Table: m.workTable,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
}
|
||||||
|
|
||||||
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||||
|
m.rConn.InsertRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: chainFwFilter,
|
||||||
|
Exprs: []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
||||||
|
expressions := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictJump,
|
||||||
|
Chain: m.routingFwChainName,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: chainFwFilter,
|
||||||
|
Exprs: expressions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createChain(name string) *nftables.Chain {
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: name,
|
||||||
|
Table: m.workTable,
|
||||||
|
}
|
||||||
|
|
||||||
|
chain = m.rConn.AddChain(chain)
|
||||||
|
|
||||||
|
insertReturnTrafficRule(m.rConn, m.workTable, chain)
|
||||||
|
|
||||||
|
return chain
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
||||||
|
polAccept := nftables.ChainPolicyAccept
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: name,
|
||||||
|
Table: m.workTable,
|
||||||
|
Hooknum: hookNum,
|
||||||
|
Priority: nftables.ChainPriorityFilter,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
Policy: &polAccept,
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.rConn.AddChain(chain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
||||||
|
expressions := []expr.Any{
|
||||||
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{Kind: expr.VerdictDrop},
|
||||||
|
}
|
||||||
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: expressions,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||||
|
expressions := []expr.Any{
|
||||||
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Verdict{
|
||||||
|
Kind: expr.VerdictJump,
|
||||||
|
Chain: to,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: chain.Table,
|
||||||
|
Chain: chain,
|
||||||
|
Exprs: expressions,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) addIpToSet(ipsetName string, ip net.IP) (*nftables.Set, error) {
|
||||||
|
ipset, err := m.rConn.GetSetByName(m.workTable, ipsetName)
|
||||||
|
rawIP := ipToBytes(ip, m.af)
|
||||||
|
if err != nil {
|
||||||
|
if ipset, err = m.createSet(m.workTable, ipsetName); err != nil {
|
||||||
|
return nil, fmt.Errorf("get set name: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ipsetStore.newIpset(ipset.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.ipsetStore.IsIpInSet(ipset.Name, ip) {
|
||||||
|
return ipset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.sConn.SetAddElements(ipset, []nftables.SetElement{{Key: rawIP}}); err != nil {
|
||||||
|
return nil, fmt.Errorf("add set element for the first time: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ipsetStore.AddIpToSet(ipset.Name, ip)
|
||||||
|
|
||||||
|
if err := m.sConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush add elements: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createSet in given table by name
|
||||||
|
func (m *AclManager) createSet(table *nftables.Table, name string) (*nftables.Set, error) {
|
||||||
|
ipset := &nftables.Set{
|
||||||
|
Name: name,
|
||||||
|
Table: table,
|
||||||
|
Dynamic: true,
|
||||||
|
KeyType: m.af.setKeyType,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.AddSet(ipset, nil); err != nil {
|
||||||
|
return nil, fmt.Errorf("create set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush created set: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) flushWithBackoff() (err error) {
|
||||||
|
backoff := 4
|
||||||
|
backoffTime := 1000 * time.Millisecond
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
err = m.rConn.Flush()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to flush nftables: %v", err)
|
||||||
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error("failed to flush nftables, retrying...")
|
||||||
|
if i == backoff-1 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
time.Sleep(backoffTime)
|
||||||
|
backoffTime *= 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||||
|
if m.workTable == nil || chain == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
list, err := m.rConn.GetRules(m.workTable, chain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range list {
|
||||||
|
if len(rule.UserData) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
split := bytes.Split(rule.UserData, []byte(" "))
|
||||||
|
r, ok := m.rules[string(split[0])]
|
||||||
|
if ok {
|
||||||
|
if mangle {
|
||||||
|
*r.mangleRule = *rule
|
||||||
|
} else {
|
||||||
|
*r.nftRule = *rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||||
|
rulesetID := ":" + string(proto) + ":"
|
||||||
|
if sPort != nil {
|
||||||
|
rulesetID += sPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
if dPort != nil {
|
||||||
|
rulesetID += dPort.String()
|
||||||
|
}
|
||||||
|
rulesetID += ":"
|
||||||
|
rulesetID += strconv.Itoa(int(action))
|
||||||
|
if ipset == nil {
|
||||||
|
return "ip:" + ip.String() + rulesetID
|
||||||
|
}
|
||||||
|
return "set:" + ipset.Name + rulesetID
|
||||||
|
}
|
||||||
|
|
||||||
|
func ifname(n string) []byte {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
copy(b, n+"\x00")
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ipToBytes converts net.IP to the correct byte length for the address family.
|
||||||
|
func ipToBytes(ip net.IP, af addrFamily) []byte {
|
||||||
|
if af.addrLen == 4 {
|
||||||
|
return ip.To4()
|
||||||
|
}
|
||||||
|
return ip.To16()
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,880 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) createContainers() error {
|
|
||||||
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRoutingFw,
|
|
||||||
Table: r.workTable,
|
|
||||||
})
|
|
||||||
|
|
||||||
prio := *nftables.ChainPriorityNATSource - 1
|
|
||||||
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRoutingNat,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: &prio,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
|
|
||||||
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameRoutingRdr,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
|
||||||
Priority: nftables.ChainPriorityNATDest,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
|
|
||||||
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameManglePostrouting,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPostrouting,
|
|
||||||
Priority: nftables.ChainPriorityMangle,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
|
|
||||||
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameManglePrerouting,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
|
||||||
Priority: nftables.ChainPriorityMangle,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
|
|
||||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameMangleForward,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityMangle,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
|
|
||||||
|
|
||||||
r.addPostroutingRules()
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("initialize tables: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addMSSClampingRules(); err != nil {
|
|
||||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Kernel routing opens both INPUT and FORWARD.
|
|
||||||
if err := r.openInterface(true); err != nil {
|
|
||||||
log.Errorf("failed to open interface in foreign chains: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
log.Errorf("failed to refresh rules: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupDataPlaneMark configures the fwmark for the data plane
|
|
||||||
func (r *family) setupDataPlaneMark() error {
|
|
||||||
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
|
||||||
return errors.New("no mangle chains found")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctNew := getCtNewExprs()
|
|
||||||
preExprs := []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
preExprs = append(preExprs, ctNew...)
|
|
||||||
preExprs = append(preExprs,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
|
||||||
},
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
preNftRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameManglePrerouting],
|
|
||||||
Exprs: preExprs,
|
|
||||||
}
|
|
||||||
r.conn.AddRule(preNftRule)
|
|
||||||
|
|
||||||
postExprs := []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyOIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
postExprs = append(postExprs, ctNew...)
|
|
||||||
postExprs = append(postExprs,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
|
||||||
},
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
postNftRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameManglePostrouting],
|
|
||||||
Exprs: postExprs,
|
|
||||||
}
|
|
||||||
r.conn.AddRule(postNftRule)
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// openInterface adds passthrough accept rules for the NetBird interface to the
|
|
||||||
// kernel's filter table and external chains so they don't drop our traffic.
|
|
||||||
// includeForward also opens the FORWARD chains (kernel routing); when false only
|
|
||||||
// INPUT is opened, which is all the userspace router needs since it never
|
|
||||||
// forwards in the kernel.
|
|
||||||
func (r *family) openInterface(includeForward bool) error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := r.acceptFilterTableRules(includeForward); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.acceptExternalChainsRules(includeForward); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add accept rules to external chains: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) acceptFilterTableRules(includeForward bool) error {
|
|
||||||
if r.filterTable == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fw := "iptables"
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
log.Debugf("Used %s to add accept input/forward rules", fw)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Try iptables first and fallback to nftables if iptables is not available.
|
|
||||||
// Use the correct protocol (iptables vs ip6tables) for the address family.
|
|
||||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
|
||||||
|
|
||||||
fw = "nftables"
|
|
||||||
return r.acceptFilterRulesNftables(r.filterTable, includeForward)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.acceptFilterRulesIptables(ipt, includeForward); err != nil {
|
|
||||||
log.Warnf("iptables failed (table may be incompatible), falling back to nftables: %v", err)
|
|
||||||
fw = "nftables"
|
|
||||||
return r.acceptFilterRulesNftables(r.filterTable, includeForward)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) acceptFilterRulesIptables(ipt *iptables.IPTables, includeForward bool) error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if includeForward {
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
|
||||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err))
|
|
||||||
} else {
|
|
||||||
log.Debugf("added iptables forward rule: %v", rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
|
||||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add iptables input rule: %v", err))
|
|
||||||
} else {
|
|
||||||
log.Debugf("added iptables input rule: %v", inputRule)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) getAcceptForwardRules() [][]string {
|
|
||||||
intf := r.wgIface.Name()
|
|
||||||
return [][]string{
|
|
||||||
{"-i", intf, "-j", "ACCEPT"},
|
|
||||||
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) getAcceptInputRule() []string {
|
|
||||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
|
||||||
}
|
|
||||||
|
|
||||||
// acceptFilterRulesNftables adds accept rules to the ip filter table using nftables.
|
|
||||||
// This is used when iptables is not available.
|
|
||||||
func (r *family) acceptFilterRulesNftables(table *nftables.Table, includeForward bool) error {
|
|
||||||
intf := ifname(r.wgIface.Name())
|
|
||||||
|
|
||||||
if includeForward {
|
|
||||||
forwardChain := &nftables.Chain{
|
|
||||||
Name: chainNameForward,
|
|
||||||
Table: table,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookForward,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
}
|
|
||||||
r.insertForwardAcceptRules(forwardChain, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
inputChain := &nftables.Chain{
|
|
||||||
Name: chainNameInput,
|
|
||||||
Table: table,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookInput,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
}
|
|
||||||
r.insertInputAcceptRule(inputChain, intf)
|
|
||||||
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
// acceptExternalChainsRules adds accept rules to external chains (non-netbird, non-iptables tables).
|
|
||||||
// It dynamically finds chains at call time to handle chains that may have been created after startup.
|
|
||||||
func (r *family) acceptExternalChainsRules(includeForward bool) error {
|
|
||||||
chains := r.findExternalChains()
|
|
||||||
if len(chains) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
intf := ifname(r.wgIface.Name())
|
|
||||||
for _, chain := range chains {
|
|
||||||
r.applyExternalChainAccept(chain, intf, includeForward)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush external chain rules: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) applyExternalChainAccept(chain *nftables.Chain, intf []byte, includeForward bool) {
|
|
||||||
if chain.Hooknum == nil {
|
|
||||||
log.Debugf("skipping external chain %s/%s: hooknum is nil", chain.Table.Name, chain.Name)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("adding accept rules to external %s chain: %s %s/%s",
|
|
||||||
hookName(chain.Hooknum), familyName(chain.Table.Family), chain.Table.Name, chain.Name)
|
|
||||||
|
|
||||||
switch *chain.Hooknum {
|
|
||||||
case *nftables.ChainHookForward:
|
|
||||||
if includeForward {
|
|
||||||
r.insertForwardAcceptRules(chain, intf)
|
|
||||||
}
|
|
||||||
case *nftables.ChainHookInput:
|
|
||||||
r.insertInputAcceptRule(chain, intf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) insertForwardAcceptRules(chain *nftables.Chain, intf []byte) {
|
|
||||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("skip forward accept rules in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.insertForwardIifRule(chain, intf, existing)
|
|
||||||
r.insertForwardOifEstablishedRule(chain, intf, existing)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) insertForwardIifRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
|
||||||
if existing[userDataAcceptForwardRuleIif] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
},
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleIif),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) insertForwardOifEstablishedRule(chain *nftables.Chain, intf []byte, existing map[string]bool) {
|
|
||||||
if existing[userDataAcceptForwardRuleOif] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
|
||||||
}
|
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: append(exprs, getEstablishedExprs(2)...),
|
|
||||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) insertInputAcceptRule(chain *nftables.Chain, intf []byte) {
|
|
||||||
existing, err := r.existingNetbirdRulesInChain(chain)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("skip input accept rule in %s/%s: %v", chain.Table.Name, chain.Name, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if existing[userDataAcceptInputRule] {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: intf},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
},
|
|
||||||
UserData: []byte(userDataAcceptInputRule),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// existingNetbirdRulesInChain returns the set of netbird-owned UserData tags present in a chain; callers must bail on error since InsertRule is additive.
|
|
||||||
func (r *family) existingNetbirdRulesInChain(chain *nftables.Chain) (map[string]bool, error) {
|
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list rules: %w", err)
|
|
||||||
}
|
|
||||||
present := map[string]bool{}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if !isNetbirdAcceptRuleTag(rule.UserData) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
present[string(rule.UserData)] = true
|
|
||||||
}
|
|
||||||
return present, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isNetbirdAcceptRuleTag(userData []byte) bool {
|
|
||||||
switch string(userData) {
|
|
||||||
case userDataAcceptForwardRuleIif,
|
|
||||||
userDataAcceptForwardRuleOif,
|
|
||||||
userDataAcceptInputRule:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeAcceptFilterRules() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := r.removeFilterTableRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeExternalChainsRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove external chain rules: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeFilterTableRules() error {
|
|
||||||
if r.filterTable == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ipt, err := iptables.NewWithProtocol(r.iptablesProto())
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("iptables not available, using nftables to remove filter rules: %v", err)
|
|
||||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeAcceptFilterRulesIptables(ipt); err != nil {
|
|
||||||
log.Debugf("iptables removal failed (table may be incompatible), falling back to nftables: %v", err)
|
|
||||||
return r.removeAcceptRulesFromTable(r.filterTable)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeAcceptRulesFromTable(table *nftables.Table) error {
|
|
||||||
chains, err := r.conn.ListChainsOfTableFamily(table.Family)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list chains: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range chains {
|
|
||||||
if chain.Table.Name != table.Name {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeAcceptRulesFromChain(table, chain); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeAcceptRulesFromChain(table *nftables.Table, chain *nftables.Chain) error {
|
|
||||||
rules, err := r.conn.GetRules(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get rules from %s/%s: %v", table.Name, chain.Name, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
|
||||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete rule from %s/%s: %v", table.Name, chain.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeExternalChainsRules removes our accept rules from all external chains.
|
|
||||||
// This is deterministic - it scans for chains at removal time rather than relying on saved state,
|
|
||||||
// ensuring cleanup works even after a crash or if chains changed.
|
|
||||||
func (r *family) removeExternalChainsRules() error {
|
|
||||||
chains := r.findExternalChains()
|
|
||||||
if len(chains) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, chain := range chains {
|
|
||||||
if err := r.removeAcceptRulesFromChain(chain.Table, chain); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove rules from external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("flush external chain %s/%s: %w", chain.Table.Name, chain.Name, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// findExternalChains scans for chains from non-netbird tables that have FORWARD or INPUT hooks.
|
|
||||||
// This is used both at startup (to know where to add rules) and at cleanup (to ensure deterministic removal).
|
|
||||||
func (r *family) findExternalChains() []*nftables.Chain {
|
|
||||||
var chains []*nftables.Chain
|
|
||||||
|
|
||||||
families := []nftables.TableFamily{r.af.tableFamily, nftables.TableFamilyINet}
|
|
||||||
|
|
||||||
for _, family := range families {
|
|
||||||
allChains, err := r.conn.ListChainsOfTableFamily(family)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("list chains for family %d: %v", family, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, chain := range allChains {
|
|
||||||
if r.isExternalChain(chain) {
|
|
||||||
chains = append(chains, chain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return chains
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) isExternalChain(chain *nftables.Chain) bool {
|
|
||||||
if r.workTable != nil && chain.Table.Name == r.workTable.Name {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip firewalld-owned chains. Firewalld creates its chains with the
|
|
||||||
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
|
|
||||||
// We delegate acceptance to firewalld by trusting the interface instead.
|
|
||||||
if chain.Table.Name == firewalldTableName {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip iptables/ip6tables-managed tables (adding nft-native rules breaks iptables-save compat)
|
|
||||||
if (chain.Table.Family == nftables.TableFamilyIPv4 || chain.Table.Family == nftables.TableFamilyIPv6) && isIptablesTable(chain.Table.Name) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if chain.Type != nftables.ChainTypeFilter {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if chain.Hooknum == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return *chain.Hooknum == *nftables.ChainHookForward || *chain.Hooknum == *nftables.ChainHookInput
|
|
||||||
}
|
|
||||||
|
|
||||||
func isIptablesTable(name string) bool {
|
|
||||||
switch name {
|
|
||||||
case tableNameFilter, tableNat, tableMangle, tableRaw, tableSecurity:
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
for _, rule := range r.getAcceptForwardRules() {
|
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inputRule := r.getAcceptInputRule()
|
|
||||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
|
||||||
//
|
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
|
||||||
func (r *family) Flush() error {
|
|
||||||
if err := r.flushWithBackoff(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRuleHandles(r.chainInputRules, false); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
|
||||||
}
|
|
||||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
|
||||||
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// queuePreroutingRule builds the prerouting mangle rule that marks
|
|
||||||
// redirected traffic and queues it on the connection without flushing,
|
|
||||||
// so the caller can commit it in the same transaction as the rule it
|
|
||||||
// pairs with. Returns nil when the prerouting chain is absent, in which
|
|
||||||
// case nothing is queued.
|
|
||||||
func (r *family) queuePreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
|
||||||
if r.chainPrerouting == nil {
|
|
||||||
log.Warn("prerouting chain is not created")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
preroutingExprs := slices.Clone(expressions)
|
|
||||||
|
|
||||||
// interface
|
|
||||||
preroutingExprs = append([]expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}, preroutingExprs...)
|
|
||||||
|
|
||||||
// local destination and mark
|
|
||||||
preroutingExprs = append(preroutingExprs,
|
|
||||||
&expr.Fib{
|
|
||||||
Register: 1,
|
|
||||||
ResultADDRTYPE: true,
|
|
||||||
FlagDADDR: true,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
|
||||||
},
|
|
||||||
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
return r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chainPrerouting,
|
|
||||||
Exprs: preroutingExprs,
|
|
||||||
UserData: userData,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) createDefaultChains() (err error) {
|
|
||||||
// chainNameInputRules
|
|
||||||
chain := r.createChain(chainNameInputRules)
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
r.chainInputRules = chain
|
|
||||||
|
|
||||||
// netbird-acl-input-filter
|
|
||||||
// type filter hook input priority filter; policy accept;
|
|
||||||
chain = r.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
|
||||||
r.addJumpRule(chain, r.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
|
|
||||||
r.addDropExpressions(chain, expr.MetaKeyIIFNAME)
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// netbird-acl-forward-filter
|
|
||||||
chainFwFilter := r.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
|
||||||
r.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
|
||||||
r.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
|
|
||||||
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.allowRedirectedTraffic(chainFwFilter); err != nil {
|
|
||||||
log.Errorf("failed to allow redirected traffic: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
|
|
||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
|
||||||
// netbird peer IP.
|
|
||||||
func (r *family) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
|
||||||
r.chainPrerouting = r.chains[chainNameManglePrerouting]
|
|
||||||
|
|
||||||
r.addFwmarkToForward(chainFwFilter)
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|
||||||
r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: chainFwFilter,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictJump,
|
|
||||||
Chain: r.routingFwChainName,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: chainFwFilter,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) createChain(name string) *nftables.Chain {
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: name,
|
|
||||||
Table: r.workTable,
|
|
||||||
}
|
|
||||||
|
|
||||||
chain = r.conn.AddChain(chain)
|
|
||||||
|
|
||||||
insertReturnTrafficRule(r.conn, r.workTable, chain)
|
|
||||||
|
|
||||||
return chain
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
|
|
||||||
polAccept := nftables.ChainPolicyAccept
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: name,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: hookNum,
|
|
||||||
Priority: nftables.ChainPriorityFilter,
|
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Policy: &polAccept,
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.conn.AddChain(chain)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictDrop},
|
|
||||||
}
|
|
||||||
_ = r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictJump,
|
|
||||||
Chain: to,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) flushWithBackoff() (err error) {
|
|
||||||
backoff := 4
|
|
||||||
backoffTime := 1000 * time.Millisecond
|
|
||||||
for i := 0; ; i++ {
|
|
||||||
err = r.conn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to flush nftables: %v", err)
|
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error("failed to flush nftables, retrying...")
|
|
||||||
if i == backoff-1 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
time.Sleep(backoffTime)
|
|
||||||
backoffTime *= 2
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
|
||||||
if r.workTable == nil || chain == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
list, err := r.conn.GetRules(r.workTable, chain)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range list {
|
|
||||||
if len(rule.UserData) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
pr, ok := r.filters[firewall.RuleID(rule.UserData)]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if mangle {
|
|
||||||
if pr.mangleRule != nil {
|
|
||||||
*pr.mangleRule = *rule
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
*pr.nftRule = *rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,565 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/google/nftables/xt"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
|
||||||
ruleID := rule.ID()
|
|
||||||
if _, exists := r.rules[ruleID+dnatSuffix]; exists {
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
protoNum, err := r.af.protoNum(rule.Protocol)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Request forwarding once the rule is about to be installed, releasing
|
|
||||||
// it if a later step fails so the refcount tracks the real rules.
|
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addDnatRedirect(rule, protoNum, ruleID); err != nil {
|
|
||||||
r.releaseForwarding()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addDnatMasq(rule, protoNum, ruleID); err != nil {
|
|
||||||
r.releaseForwarding()
|
|
||||||
delete(r.rules, ruleID+dnatSuffix)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
|
||||||
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
|
||||||
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
|
||||||
// TODO: find chains with drop policies and add rules there
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
r.releaseForwarding()
|
|
||||||
return nil, fmt.Errorf("flush rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
|
|
||||||
dnatExprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{protoNum},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
portExprs, err := r.applyPort(&rule.DestinationPort, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("apply destination port: %w", err)
|
|
||||||
}
|
|
||||||
dnatExprs = append(dnatExprs, portExprs...)
|
|
||||||
|
|
||||||
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
|
||||||
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
|
||||||
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
|
||||||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
|
||||||
return r.addXTablesRedirect(dnatExprs, ruleID, rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
dnatExprs = append(dnatExprs, additionalExprs...)
|
|
||||||
|
|
||||||
dnatExprs = append(dnatExprs,
|
|
||||||
&expr.NAT{
|
|
||||||
Type: expr.NATTypeDestNAT,
|
|
||||||
Family: uint32(r.af.tableFamily),
|
|
||||||
RegAddrMin: 1,
|
|
||||||
RegProtoMin: regProtoMin,
|
|
||||||
RegProtoMax: regProtoMax,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
dnatRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingRdr],
|
|
||||||
Exprs: dnatExprs,
|
|
||||||
UserData: []byte(ruleID + dnatSuffix),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(dnatRule)
|
|
||||||
r.rules[ruleID+dnatSuffix] = dnatRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
|
||||||
switch {
|
|
||||||
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
|
||||||
return r.handlePortRange(rule)
|
|
||||||
case len(rule.TranslatedPort.Values) == 0:
|
|
||||||
return r.handleAddressOnly(rule)
|
|
||||||
case len(rule.TranslatedPort.Values) == 1:
|
|
||||||
return r.handleSinglePort(rule)
|
|
||||||
default:
|
|
||||||
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: rule.TranslatedAddress.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 3,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return exprs, 2, 3, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: rule.TranslatedAddress.AsSlice(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return exprs, 0, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: rule.TranslatedAddress.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return exprs, 2, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addXTablesRedirect(dnatExprs []expr.Any, ruleID firewall.RuleID, rule firewall.ForwardRule) error {
|
|
||||||
dnatExprs = append(dnatExprs,
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Target{
|
|
||||||
Name: "DNAT",
|
|
||||||
Rev: 2,
|
|
||||||
Info: &xt.NatRange2{
|
|
||||||
NatRange: xt.NatRange{
|
|
||||||
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
|
||||||
MinIP: rule.TranslatedAddress.AsSlice(),
|
|
||||||
MaxIP: rule.TranslatedAddress.AsSlice(),
|
|
||||||
MinPort: rule.TranslatedPort.Values[0],
|
|
||||||
MaxPort: rule.TranslatedPort.Values[1],
|
|
||||||
},
|
|
||||||
BasePort: rule.DestinationPort.Values[0],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
natTable := &nftables.Table{
|
|
||||||
Name: tableNat,
|
|
||||||
Family: r.af.tableFamily,
|
|
||||||
}
|
|
||||||
dnatRule := &nftables.Rule{
|
|
||||||
Table: natTable,
|
|
||||||
Chain: &nftables.Chain{
|
|
||||||
Name: chainNameNatPrerouting,
|
|
||||||
Table: natTable,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
|
||||||
Priority: nftables.ChainPriorityNATDest,
|
|
||||||
},
|
|
||||||
Exprs: dnatExprs,
|
|
||||||
UserData: []byte(ruleID + dnatSuffix),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(dnatRule)
|
|
||||||
r.rules[ruleID+dnatSuffix] = dnatRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleID firewall.RuleID) error {
|
|
||||||
portExprs, err := r.applyPort(&rule.TranslatedPort, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("apply translated port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
masqExprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{protoNum},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: r.af.dstAddrOffset,
|
|
||||||
Len: r.af.addrLen,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: rule.TranslatedAddress.AsSlice(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
masqExprs = append(masqExprs, portExprs...)
|
|
||||||
masqExprs = append(masqExprs, &expr.Masq{})
|
|
||||||
|
|
||||||
masqRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingNat],
|
|
||||||
Exprs: masqExprs,
|
|
||||||
UserData: []byte(ruleID + snatSuffix),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(masqRule)
|
|
||||||
r.rules[ruleID+snatSuffix] = masqRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) DeleteDNATRule(rule firewall.Rule) error {
|
|
||||||
ruleID := rule.ID()
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
var needsFlush bool
|
|
||||||
var found bool
|
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleID+dnatSuffix]; exists {
|
|
||||||
found = true
|
|
||||||
if dnatRule.Handle == 0 {
|
|
||||||
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleID+dnatSuffix)
|
|
||||||
delete(r.rules, ruleID+dnatSuffix)
|
|
||||||
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
|
||||||
} else {
|
|
||||||
needsFlush = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleID+snatSuffix]; exists {
|
|
||||||
found = true
|
|
||||||
if masqRule.Handle == 0 {
|
|
||||||
log.Warnf("snat rule %s has no handle, removing stale entry", ruleID+snatSuffix)
|
|
||||||
delete(r.rules, ruleID+snatSuffix)
|
|
||||||
} else if err := r.conn.DelRule(masqRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
|
||||||
} else {
|
|
||||||
needsFlush = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if needsFlush {
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if merr != nil {
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(r.rules, ruleID+dnatSuffix)
|
|
||||||
delete(r.rules, ruleID+snatSuffix)
|
|
||||||
|
|
||||||
// Release once, only if the rule was present and removed.
|
|
||||||
if found {
|
|
||||||
r.releaseForwarding()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseForwarding drops one IP forwarding reference, logging any error.
|
|
||||||
func (r *family) releaseForwarding() {
|
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
log.Errorf("release IP forwarding: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
protoNum, err := r.af.protoNum(protocol)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 2,
|
|
||||||
Data: []byte{protoNum},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 3,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 3,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
bits := 32
|
|
||||||
if localAddr.Is6() {
|
|
||||||
bits = 128
|
|
||||||
}
|
|
||||||
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
|
||||||
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: localAddr.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
|
||||||
},
|
|
||||||
&expr.NAT{
|
|
||||||
Type: expr.NATTypeDestNAT,
|
|
||||||
Family: uint32(r.af.tableFamily),
|
|
||||||
RegAddrMin: 1,
|
|
||||||
RegProtoMin: 2,
|
|
||||||
RegProtoMax: 0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
dnatRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingRdr],
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(ruleID),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(dnatRule)
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("add inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleID] = dnatRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
|
||||||
func (r *family) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ensureNATOutputChain lazily creates the OUTPUT NAT chain on first use.
|
|
||||||
func (r *family) ensureNATOutputChain() error {
|
|
||||||
if _, exists := r.chains[chainNameNATOutput]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.chains[chainNameNATOutput] = r.conn.AddChain(&nftables.Chain{
|
|
||||||
Name: chainNameNATOutput,
|
|
||||||
Table: r.workTable,
|
|
||||||
Hooknum: nftables.ChainHookOutput,
|
|
||||||
Priority: nftables.ChainPriorityNATDest,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
delete(r.chains, chainNameNATOutput)
|
|
||||||
return fmt.Errorf("create NAT output chain: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
|
||||||
func (r *family) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.ensureNATOutputChain(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoNum, err := r.af.protoNum(protocol)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert protocol to number: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{protoNum},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(originalPort),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
bits := 32
|
|
||||||
if localAddr.Is6() {
|
|
||||||
bits = 128
|
|
||||||
}
|
|
||||||
exprs = append(exprs, prefixMatchExprs(r.af, netip.PrefixFrom(localAddr, bits), false)...)
|
|
||||||
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: localAddr.AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 2,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(translatedPort),
|
|
||||||
},
|
|
||||||
&expr.NAT{
|
|
||||||
Type: expr.NATTypeDestNAT,
|
|
||||||
Family: uint32(r.af.tableFamily),
|
|
||||||
RegAddrMin: 1,
|
|
||||||
RegProtoMin: 2,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
dnatRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameNATOutput],
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(ruleID),
|
|
||||||
}
|
|
||||||
r.conn.AddRule(dnatRule)
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("add output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleID] = dnatRule
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
|
||||||
func (r *family) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ruleID := firewall.RuleID(fmt.Sprintf("output-dnat-%s-%s-%d-%d", localAddr.String(), protocol, originalPort, translatedPort))
|
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("output DNAT rule %s has no handle, removing stale entry", ruleID)
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("delete output DNAT rule %s: %w", ruleID, err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete output DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
tableNat = "nat"
|
|
||||||
tableMangle = "mangle"
|
|
||||||
tableRaw = "raw"
|
|
||||||
tableSecurity = "security"
|
|
||||||
|
|
||||||
chainNameNatPrerouting = "PREROUTING"
|
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
|
||||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
|
||||||
chainNameNATOutput = "netbird-nat-output"
|
|
||||||
chainNameForward = "FORWARD"
|
|
||||||
chainNameMangleForward = "netbird-mangle-forward"
|
|
||||||
|
|
||||||
// Peer ACL chain names.
|
|
||||||
chainNameInputRules = "netbird-acl-input-rules"
|
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
|
||||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
|
||||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
|
||||||
|
|
||||||
flushError = "flush: %w"
|
|
||||||
|
|
||||||
firewalldTableName = "firewalld"
|
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
|
||||||
userDataAcceptInputRule = "inputaccept"
|
|
||||||
|
|
||||||
dnatSuffix firewall.RuleID = "_dnat"
|
|
||||||
snatSuffix firewall.RuleID = "_snat"
|
|
||||||
|
|
||||||
// ipv4TCPHeaderSize is the minimum IPv4 (20) + TCP (20) header size for MSS calculation.
|
|
||||||
ipv4TCPHeaderSize = 40
|
|
||||||
// ipv6TCPHeaderSize is the minimum IPv6 (40) + TCP (20) header size for MSS calculation.
|
|
||||||
ipv6TCPHeaderSize = 60
|
|
||||||
|
|
||||||
// maxPrefixesSet 1638 prefixes start to fail, taking some margin
|
|
||||||
maxPrefixesSet = 1500
|
|
||||||
refreshRulesMapError = "refresh rules map: %w"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
|
||||||
)
|
|
||||||
|
|
||||||
type setInput struct {
|
|
||||||
set firewall.Set
|
|
||||||
prefixes []netip.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
// family holds the per-address-family nftables state. One instance
|
|
||||||
// handles route ACLs, peer ACLs, NAT, DNAT, and MSS clamping for a
|
|
||||||
// single family; the top-level Manager owns one for v4 and another
|
|
||||||
// for v6. The name predates the peer-ACL absorption; it's effectively
|
|
||||||
// the per-family backend now.
|
|
||||||
type family struct {
|
|
||||||
conn *nftables.Conn
|
|
||||||
workTable *nftables.Table
|
|
||||||
filterTable *nftables.Table
|
|
||||||
chains map[string]*nftables.Chain
|
|
||||||
|
|
||||||
// filters holds peer + route filter rules keyed by content hash.
|
|
||||||
// AddFilterRule writes here; DeleteFilterRule looks up by id.
|
|
||||||
filters map[firewall.RuleID]*Rule
|
|
||||||
|
|
||||||
// rules holds NAT, DNAT, and external accept rules (auxiliary
|
|
||||||
// plumbing that isn't a filter rule).
|
|
||||||
rules map[firewall.RuleID]*nftables.Rule
|
|
||||||
|
|
||||||
// Peer ACL chain handles.
|
|
||||||
chainInputRules *nftables.Chain
|
|
||||||
chainPrerouting *nftables.Chain
|
|
||||||
routingFwChainName string
|
|
||||||
|
|
||||||
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
|
||||||
|
|
||||||
af addrFamily
|
|
||||||
wgIface iFaceMapper
|
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
|
||||||
legacyManagement bool
|
|
||||||
mtu uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFamily(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) *family {
|
|
||||||
r := &family{
|
|
||||||
conn: &nftables.Conn{},
|
|
||||||
workTable: workTable,
|
|
||||||
chains: make(map[string]*nftables.Chain),
|
|
||||||
filters: make(map[firewall.RuleID]*Rule),
|
|
||||||
rules: make(map[firewall.RuleID]*nftables.Rule),
|
|
||||||
routingFwChainName: chainNameRoutingFw,
|
|
||||||
af: familyForAddr(workTable.Family == nftables.TableFamilyIPv4),
|
|
||||||
wgIface: wgIface,
|
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
|
||||||
mtu: mtu,
|
|
||||||
}
|
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
|
||||||
r.createIpSet,
|
|
||||||
r.deleteIpSet,
|
|
||||||
)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
r.filterTable, err = r.loadFilterTable()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("ip filter table not found: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) init(workTable *nftables.Table) error {
|
|
||||||
r.workTable = workTable
|
|
||||||
|
|
||||||
if err := r.removeAcceptFilterRules(); err != nil {
|
|
||||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.createContainers(); err != nil {
|
|
||||||
return fmt.Errorf("create containers: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.setupDataPlaneMark(); err != nil {
|
|
||||||
log.Errorf("failed to set up data plane mark: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.createDefaultChains(); err != nil {
|
|
||||||
return fmt.Errorf("create default acl chains: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset cleans existing nftables filter table rules from the system
|
|
||||||
func (r *family) Reset() error {
|
|
||||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
|
||||||
r.ipsetCounter.Clear()
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if err := r.removeAcceptFilterRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeNatPreroutingRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) loadFilterTable() (*nftables.Table, error) {
|
|
||||||
tables, err := r.conn.ListTablesOfFamily(r.af.tableFamily)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("list tables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range tables {
|
|
||||||
if table.Name == "filter" {
|
|
||||||
return table, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errFilterTableNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func hookName(hook *nftables.ChainHook) string {
|
|
||||||
if hook == nil {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
switch *hook {
|
|
||||||
case *nftables.ChainHookForward:
|
|
||||||
return chainNameForward
|
|
||||||
case *nftables.ChainHookInput:
|
|
||||||
return chainNameInput
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("hook(%d)", *hook)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func familyName(family nftables.TableFamily) string {
|
|
||||||
switch family {
|
|
||||||
case nftables.TableFamilyIPv4:
|
|
||||||
return "ip"
|
|
||||||
case nftables.TableFamilyIPv6:
|
|
||||||
return "ip6"
|
|
||||||
case nftables.TableFamilyINet:
|
|
||||||
return "inet"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("family(%d)", family)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) iptablesProto() iptables.Protocol {
|
|
||||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
|
||||||
return iptables.ProtocolIPv6
|
|
||||||
}
|
|
||||||
return iptables.ProtocolIPv4
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) refreshRulesMap() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
newRules := make(map[firewall.RuleID]*nftables.Rule)
|
|
||||||
for _, chain := range r.chains {
|
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
|
||||||
if err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
|
||||||
// preserve existing entries for this chain since we can't verify their state
|
|
||||||
for k, v := range r.rules {
|
|
||||||
if v.Chain != nil && v.Chain.Name == chain.Name {
|
|
||||||
newRules[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) > 0 {
|
|
||||||
newRules[firewall.RuleID(rule.UserData)] = rule
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.rules = newRules
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
@@ -1,512 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddFilterRule installs one nftables packet-filter rule. With
|
|
||||||
// destination empty the rule goes to the peer ACL input chain plus a
|
|
||||||
// paired prerouting mangle rule for the redirect mark. With
|
|
||||||
// destination set (prefix or named set) it goes to the route ACL
|
|
||||||
// forward chain. Multi-source rules collapse to one nftables rule
|
|
||||||
// backed by the shared refcounted hash:net set.
|
|
||||||
func (r *family) AddFilterRule(
|
|
||||||
id []byte,
|
|
||||||
sources []netip.Prefix,
|
|
||||||
destination firewall.Network,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
|
||||||
) (firewall.Rule, error) {
|
|
||||||
isRoute := !destination.IsZero()
|
|
||||||
|
|
||||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
|
||||||
if existing, ok := r.filters[ruleID]; ok {
|
|
||||||
return existing, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
srcExprs, err := r.applyNetwork(sourceNetwork(sources), sources, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("apply source: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var exprs []expr.Any
|
|
||||||
if isRoute {
|
|
||||||
exprs, err = r.buildRouteFilterExprs(srcExprs, destination, proto, sPort, dPort)
|
|
||||||
} else {
|
|
||||||
exprs, err = r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
r.dropNetworkMatch(srcExprs)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
mainExprs := slices.Clone(exprs)
|
|
||||||
verdict := expr.VerdictAccept
|
|
||||||
if action == firewall.ActionDrop {
|
|
||||||
verdict = expr.VerdictDrop
|
|
||||||
}
|
|
||||||
mainExprs = append(mainExprs, &expr.Verdict{Kind: verdict})
|
|
||||||
|
|
||||||
chain := r.chainInputRules
|
|
||||||
if isRoute {
|
|
||||||
chain = r.chains[chainNameRoutingFw]
|
|
||||||
}
|
|
||||||
|
|
||||||
userData := []byte(ruleID)
|
|
||||||
|
|
||||||
// Build the paired prerouting mangle rule before flushing so both
|
|
||||||
// rules commit in one transaction. An anonymous port set binds to
|
|
||||||
// exactly one rule, so the mangle rule needs its own expression list
|
|
||||||
// with fresh sets, not a clone of the main rule's. Guard on the
|
|
||||||
// prerouting chain first: building the expressions queues the port
|
|
||||||
// set, so skipping the build when there is no chain to bind it to
|
|
||||||
// keeps an unbound set out of the connection batch.
|
|
||||||
var mangleRule *nftables.Rule
|
|
||||||
if !isRoute && r.chainPrerouting != nil {
|
|
||||||
mangleExprs, err := r.buildPeerFilterExprs(srcExprs, proto, sPort, dPort)
|
|
||||||
if err != nil {
|
|
||||||
r.dropNetworkMatch(exprs)
|
|
||||||
return nil, fmt.Errorf("build mangle rule: %w", err)
|
|
||||||
}
|
|
||||||
mangleRule = r.queuePreroutingRule(mangleExprs, userData)
|
|
||||||
}
|
|
||||||
|
|
||||||
nftRule := &nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: mainExprs,
|
|
||||||
UserData: userData,
|
|
||||||
}
|
|
||||||
if action == firewall.ActionDrop {
|
|
||||||
nftRule = r.conn.InsertRule(nftRule)
|
|
||||||
} else {
|
|
||||||
nftRule = r.conn.AddRule(nftRule)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
r.dropNetworkMatch(exprs)
|
|
||||||
return nil, fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := &Rule{
|
|
||||||
nftRule: nftRule,
|
|
||||||
mangleRule: mangleRule,
|
|
||||||
sources: sources,
|
|
||||||
id: ruleID,
|
|
||||||
}
|
|
||||||
r.filters[ruleID] = rule
|
|
||||||
|
|
||||||
log.Debugf("added filter rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v",
|
|
||||||
sources, destination, proto, sPort, dPort, action)
|
|
||||||
return rule, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildPeerFilterExprs assembles the input-chain (peer ACL) match: the
|
|
||||||
// IP-header protocol byte read via Payload, then source, then ports
|
|
||||||
// (no counter), matching the historical peer shape so per-rule kernel
|
|
||||||
// state is identical to pre-unification.
|
|
||||||
func (r *family) buildPeerFilterExprs(
|
|
||||||
srcExprs []expr.Any,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort, dPort *firewall.Port,
|
|
||||||
) ([]expr.Any, error) {
|
|
||||||
var exprs []expr.Any
|
|
||||||
|
|
||||||
if proto != firewall.ProtocolALL {
|
|
||||||
protoNum, err := r.af.protoNum(proto)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
|
||||||
}
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: r.af.protoOffset,
|
|
||||||
Len: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
exprs = append(exprs, srcExprs...)
|
|
||||||
|
|
||||||
portExprs, err := r.applyPorts(sPort, dPort)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
exprs = append(exprs, portExprs...)
|
|
||||||
return exprs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildRouteFilterExprs assembles the forward-chain (route ACL) match:
|
|
||||||
// source, then destination, then optional proto/ports, then a counter.
|
|
||||||
func (r *family) buildRouteFilterExprs(
|
|
||||||
srcExprs []expr.Any,
|
|
||||||
destination firewall.Network,
|
|
||||||
proto firewall.Protocol,
|
|
||||||
sPort, dPort *firewall.Port,
|
|
||||||
) ([]expr.Any, error) {
|
|
||||||
exprs := append([]expr.Any{}, srcExprs...)
|
|
||||||
|
|
||||||
destExprs, err := r.applyNetwork(destination, nil, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("apply destination: %w", err)
|
|
||||||
}
|
|
||||||
exprs = append(exprs, destExprs...)
|
|
||||||
|
|
||||||
if proto != firewall.ProtocolALL {
|
|
||||||
protoNum, err := r.af.protoNum(proto)
|
|
||||||
if err != nil {
|
|
||||||
r.dropNetworkMatch(destExprs)
|
|
||||||
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
|
||||||
}
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
|
||||||
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{protoNum}},
|
|
||||||
)
|
|
||||||
|
|
||||||
portExprs, err := r.applyPorts(sPort, dPort)
|
|
||||||
if err != nil {
|
|
||||||
r.dropNetworkMatch(destExprs)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
exprs = append(exprs, portExprs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs = append(exprs, &expr.Counter{})
|
|
||||||
return exprs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) hasRule(id firewall.RuleID) bool {
|
|
||||||
_, ok := r.filters[id]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) hasDNATRule(id firewall.RuleID) bool {
|
|
||||||
_, ok := r.rules[id+dnatSuffix]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteFilterRule removes a previously installed filter rule. Source
|
|
||||||
// set references are recovered from the stored rule's expressions via
|
|
||||||
// findSets and dropped from the shared refcounter.
|
|
||||||
func (r *family) DeleteFilterRule(rule firewall.Rule) error {
|
|
||||||
ruleID := rule.ID()
|
|
||||||
pr, ok := r.filters[ruleID]
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("filter rule %s not found", ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// A freshly added rule carries no handle until it is read back from
|
|
||||||
// the kernel, and Flush only refreshes the peer chains. Pull live
|
|
||||||
// handles for this rule's chain before deciding it is stale so route
|
|
||||||
// rules (which Flush never refreshes) can actually be deleted. A
|
|
||||||
// refresh failure aborts the delete without touching tracking state,
|
|
||||||
// so the caller can retry while the rule may still exist in the kernel.
|
|
||||||
if pr.nftRule.Handle == 0 {
|
|
||||||
if err := r.refreshRuleHandles(pr.nftRule.Chain, false); err != nil {
|
|
||||||
return fmt.Errorf("refresh handles for chain %s: %w", pr.nftRule.Chain.Name, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Refresh the mangle handle independently: the main rule's handle can
|
|
||||||
// be populated while the prerouting refresh during Flush failed, and
|
|
||||||
// gating the mangle refresh on the main handle would leak the mangle
|
|
||||||
// rule on delete.
|
|
||||||
if pr.mangleRule != nil && pr.mangleRule.Handle == 0 {
|
|
||||||
if err := r.refreshRuleHandles(r.chainPrerouting, true); err != nil {
|
|
||||||
return fmt.Errorf("refresh mangle handles: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pr.nftRule.Handle == 0 {
|
|
||||||
log.Warnf("filter rule %s has no handle, removing stale entry", ruleID)
|
|
||||||
r.dropNetworkMatch(pr.nftRule.Exprs)
|
|
||||||
delete(r.filters, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(pr.nftRule); err != nil {
|
|
||||||
log.Errorf("queue rule delete: %v", err)
|
|
||||||
}
|
|
||||||
if pr.mangleRule != nil {
|
|
||||||
if err := r.conn.DelRule(pr.mangleRule); err != nil {
|
|
||||||
log.Errorf("queue mangle rule delete: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete %s: %w", ruleID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.dropNetworkMatch(pr.nftRule.Exprs)
|
|
||||||
delete(r.filters, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) decrementSetCounter(rule *nftables.Rule) error {
|
|
||||||
if r.ipsetCounter == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sets := findSets(rule)
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, setName := range sets {
|
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dropNetworkMatch undoes whatever the source/destination match
|
|
||||||
// reserved. Safe to call when the spec is empty or holds only inline
|
|
||||||
// matchers.
|
|
||||||
func (r *family) dropNetworkMatch(exprs []expr.Any) {
|
|
||||||
if r.ipsetCounter == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for _, e := range exprs {
|
|
||||||
lookup, ok := e.(*expr.Lookup)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if _, err := r.ipsetCounter.Decrement(lookup.SetName); err != nil {
|
|
||||||
log.Errorf("rollback ipset decrement %s: %v", lookup.SetName, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) applyNetwork(
|
|
||||||
network firewall.Network,
|
|
||||||
setPrefixes []netip.Prefix,
|
|
||||||
isSource bool,
|
|
||||||
) ([]expr.Any, error) {
|
|
||||||
if network.IsSet() {
|
|
||||||
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
|
||||||
if err != nil {
|
|
||||||
side := "destination"
|
|
||||||
if isSource {
|
|
||||||
side = "source"
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("%s set: %w", side, err)
|
|
||||||
}
|
|
||||||
return exprs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if network.IsPrefix() {
|
|
||||||
return prefixMatchExprs(r.af, network.Prefix, isSource), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyPort builds the transport-header port match. A single value
|
|
||||||
// compares directly, a range uses a range expression, and multiple
|
|
||||||
// values go through an anonymous constant set: consecutive cmp
|
|
||||||
// expressions AND together, so chained equality comparisons could
|
|
||||||
// never match more than one port. The set is queued on the
|
|
||||||
// connection and committed by the caller's flush together with the
|
|
||||||
// rule that binds it.
|
|
||||||
func (r *family) applyPort(port *firewall.Port, isSource bool) ([]expr.Any, error) {
|
|
||||||
if port == nil || len(port.Values) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// dst port
|
|
||||||
offset := uint32(2)
|
|
||||||
if isSource {
|
|
||||||
// src port
|
|
||||||
offset = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: offset,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case port.IsRange && len(port.Values) == 2:
|
|
||||||
exprs = append(exprs, &expr.Range{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
|
||||||
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
|
||||||
})
|
|
||||||
case len(port.Values) == 1:
|
|
||||||
exprs = append(exprs, &expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
lookup, err := r.anonymousPortSet(port.Values)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
exprs = append(exprs, lookup)
|
|
||||||
}
|
|
||||||
|
|
||||||
return exprs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// anonymousPortSet queues an anonymous constant set holding the given
|
|
||||||
// ports on the connection and returns a lookup against it. The set is
|
|
||||||
// committed by the caller's flush together with the rule that binds it.
|
|
||||||
func (r *family) anonymousPortSet(values []uint16) (*expr.Lookup, error) {
|
|
||||||
set := &nftables.Set{
|
|
||||||
Anonymous: true,
|
|
||||||
Constant: true,
|
|
||||||
Table: r.workTable,
|
|
||||||
KeyType: nftables.TypeInetService,
|
|
||||||
}
|
|
||||||
elements := make([]nftables.SetElement, 0, len(values))
|
|
||||||
for _, p := range values {
|
|
||||||
elements = append(elements, nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)})
|
|
||||||
}
|
|
||||||
if err := r.conn.AddSet(set, elements); err != nil {
|
|
||||||
return nil, fmt.Errorf("add anonymous port set: %w", err)
|
|
||||||
}
|
|
||||||
return &expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetID: set.ID,
|
|
||||||
SetName: set.Name,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// applyPorts builds the source then destination port matches.
|
|
||||||
func (r *family) applyPorts(sPort, dPort *firewall.Port) ([]expr.Any, error) {
|
|
||||||
sPortExprs, err := r.applyPort(sPort, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("apply source port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dPortExprs, err := r.applyPort(dPort, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("apply destination port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return append(sPortExprs, dPortExprs...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// prefixMatchExprs is the family-aware match sequence for a CIDR
|
|
||||||
// prefix. /0 returns nil; a host prefix (full bit length for the
|
|
||||||
// family) skips the bitwise step since the mask is all-ones. Shared
|
|
||||||
// between family and aclManager so both treat single prefixes
|
|
||||||
// identically.
|
|
||||||
func prefixMatchExprs(af addrFamily, prefix netip.Prefix, isSource bool) []expr.Any {
|
|
||||||
offset := af.dstAddrOffset
|
|
||||||
if isSource {
|
|
||||||
offset = af.srcAddrOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
ones := prefix.Bits()
|
|
||||||
if ones == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := &expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: offset,
|
|
||||||
Len: af.addrLen,
|
|
||||||
}
|
|
||||||
cmp := &expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: prefix.Masked().Addr().AsSlice(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if ones == af.totalBits {
|
|
||||||
return []expr.Any{payload, cmp}
|
|
||||||
}
|
|
||||||
|
|
||||||
mask := net.CIDRMask(ones, af.totalBits)
|
|
||||||
xor := make([]byte, af.addrLen)
|
|
||||||
return []expr.Any{
|
|
||||||
payload,
|
|
||||||
&expr.Bitwise{
|
|
||||||
DestRegister: 1,
|
|
||||||
SourceRegister: 1,
|
|
||||||
Len: af.addrLen,
|
|
||||||
Mask: mask,
|
|
||||||
Xor: xor,
|
|
||||||
},
|
|
||||||
cmp,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getCtNewExprs() []expr.Any {
|
|
||||||
return []expr.Any{
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// sourceNetwork classifies a source-prefix list into the firewall.Network
|
|
||||||
// shape the rest of the spec-builder consumes: empty for match-any, a
|
|
||||||
// single prefix inline, or an ipset for multiple sources.
|
|
||||||
func sourceNetwork(sources []netip.Prefix) firewall.Network {
|
|
||||||
switch {
|
|
||||||
case len(sources) == 0:
|
|
||||||
return firewall.Network{}
|
|
||||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
|
||||||
return firewall.Network{}
|
|
||||||
case len(sources) == 1:
|
|
||||||
return firewall.Network{Prefix: sources[0]}
|
|
||||||
default:
|
|
||||||
return firewall.Network{Set: firewall.NewPrefixSet(sources)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
|
||||||
b := make([]byte, 16)
|
|
||||||
copy(b, n+"\x00")
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// findSets scans an nftables rule's expressions for expr.Lookup and
|
|
||||||
// returns the named sets in occurrence order. Used at delete time to
|
|
||||||
// drop ipsetCounter references; peer and route ACLs go through it.
|
|
||||||
func findSets(rule *nftables.Rule) []string {
|
|
||||||
var sets []string
|
|
||||||
for _, e := range rule.Exprs {
|
|
||||||
if lookup, ok := e.(*expr.Lookup); ok {
|
|
||||||
sets = append(sets, lookup.SetName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sets
|
|
||||||
}
|
|
||||||
@@ -1,90 +0,0 @@
|
|||||||
//go:build integration && !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestInterfaceAllowerInputOnly verifies the userspace-mode allower opens the
|
|
||||||
// interface on the INPUT hook of foreign chains only (not FORWARD, since the
|
|
||||||
// userspace router never forwards in the kernel), creates no netbird work
|
|
||||||
// table, and removes its rules on Close.
|
|
||||||
func TestInterfaceAllowerInputOnly(t *testing.T) {
|
|
||||||
if os.Geteuid() != 0 {
|
|
||||||
t.Skip("root required")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.False(t, ipTableExists(t, getTableName()), "precondition: no stale netbird table")
|
|
||||||
|
|
||||||
conn := &nftables.Conn{}
|
|
||||||
extTable := conn.AddTable(&nftables.Table{Name: "nbtest_extchains", Family: nftables.TableFamilyINet})
|
|
||||||
inputChain := conn.AddChain(&nftables.Chain{
|
|
||||||
Name: "ext_input", Table: extTable,
|
|
||||||
Hooknum: nftables.ChainHookInput, Priority: nftables.ChainPriorityFilter, Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
forwardChain := conn.AddChain(&nftables.Chain{
|
|
||||||
Name: "ext_forward", Table: extTable,
|
|
||||||
Hooknum: nftables.ChainHookForward, Priority: nftables.ChainPriorityFilter, Type: nftables.ChainTypeFilter,
|
|
||||||
})
|
|
||||||
require.NoError(t, conn.Flush(), "create external table and chains")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
c := &nftables.Conn{}
|
|
||||||
c.DelTable(extTable)
|
|
||||||
_ = c.Flush()
|
|
||||||
})
|
|
||||||
|
|
||||||
allower, err := NewInterfaceAllower(ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err, "create allower")
|
|
||||||
require.NoError(t, allower.Apply(), "apply")
|
|
||||||
|
|
||||||
require.True(t, chainHasUserData(t, extTable, inputChain, userDataAcceptInputRule),
|
|
||||||
"external INPUT chain should get the accept rule")
|
|
||||||
require.Len(t, listRules(t, extTable, forwardChain), 0,
|
|
||||||
"external FORWARD chain must not be opened in userspace mode")
|
|
||||||
require.False(t, ipTableExists(t, getTableName()),
|
|
||||||
"allower must not create a netbird work table")
|
|
||||||
|
|
||||||
require.NoError(t, allower.Close(), "close")
|
|
||||||
require.False(t, chainHasUserData(t, extTable, inputChain, userDataAcceptInputRule),
|
|
||||||
"accept rule should be removed on close")
|
|
||||||
}
|
|
||||||
|
|
||||||
func listRules(t *testing.T, table *nftables.Table, chain *nftables.Chain) []*nftables.Rule {
|
|
||||||
t.Helper()
|
|
||||||
c := &nftables.Conn{}
|
|
||||||
rules, err := c.GetRules(table, chain)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return rules
|
|
||||||
}
|
|
||||||
|
|
||||||
func chainHasUserData(t *testing.T, table *nftables.Table, chain *nftables.Chain, ud string) bool {
|
|
||||||
for _, r := range listRules(t, table, chain) {
|
|
||||||
if bytes.Equal(r.UserData, []byte(ud)) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func ipTableExists(t *testing.T, name string) bool {
|
|
||||||
t.Helper()
|
|
||||||
c := &nftables.Conn{}
|
|
||||||
for _, fam := range []nftables.TableFamily{nftables.TableFamilyIPv4, nftables.TableFamilyIPv6} {
|
|
||||||
tbls, err := c.ListTablesOfFamily(fam)
|
|
||||||
require.NoError(t, err)
|
|
||||||
for _, tb := range tbls {
|
|
||||||
if tb.Name == name {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
// InterfaceAllower opens the NetBird interface in the kernel's filter table and
|
|
||||||
// external chains and keeps them reconciled via a netlink monitor, so the host
|
|
||||||
// firewall doesn't drop traffic the NetBird firewall handles. It is used by the
|
|
||||||
// userspace firewall, where routing happens in the forwarder, so only INPUT is
|
|
||||||
// opened (the userspace router never forwards in the kernel).
|
|
||||||
//
|
|
||||||
// It owns its own families/connection and never creates a netbird work table.
|
|
||||||
// firewalld trust is handled by the caller, not here. Its operations are serial
|
|
||||||
// (Apply before the monitor starts; reconciles run on the single monitor
|
|
||||||
// goroutine; Close stops the monitor before removing), so it needs no locking.
|
|
||||||
//
|
|
||||||
// TODO: this opens nftables and the iptables-nft filter table (detected via
|
|
||||||
// nft), but not a legacy-iptables ruleset running in parallel with nftables.
|
|
||||||
// Such a host would keep its legacy filter chains closed for the interface.
|
|
||||||
type InterfaceAllower struct {
|
|
||||||
family4 *family
|
|
||||||
family6 *family
|
|
||||||
extMonitor *externalChainMonitor
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewInterfaceAllower builds an allower for the given interface. It returns an
|
|
||||||
// error when nftables is unavailable (e.g. an iptables-legacy host), so the
|
|
||||||
// caller can fall back to firewalld trust.
|
|
||||||
func NewInterfaceAllower(wgIface iFaceMapper, mtu uint16) (*InterfaceAllower, error) {
|
|
||||||
tableName := getTableName()
|
|
||||||
|
|
||||||
family4 := newFamily(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}, wgIface, mtu)
|
|
||||||
|
|
||||||
// Probe nftables availability before committing to this backend.
|
|
||||||
if _, err := family4.conn.ListChainsOfTableFamily(nftables.TableFamilyINet); err != nil {
|
|
||||||
return nil, fmt.Errorf("nftables not available: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
a := &InterfaceAllower{family4: family4}
|
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
|
||||||
a.family6 = newFamily(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}, wgIface, mtu)
|
|
||||||
}
|
|
||||||
|
|
||||||
a.extMonitor = newExternalChainMonitor(a)
|
|
||||||
return a, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply opens the interface (INPUT only) in the foreign filter chains and starts
|
|
||||||
// reconciling them on nftables changes.
|
|
||||||
func (a *InterfaceAllower) Apply() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, f := range a.families() {
|
|
||||||
// Remove any stale accepts first so a prior unclean exit (e.g. SIGKILL,
|
|
||||||
// where Close never ran) is recovered deterministically rather than
|
|
||||||
// accumulating duplicate rules on the iptables filter table.
|
|
||||||
if err := f.removeAcceptFilterRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("clean stale accept rules: %w", err))
|
|
||||||
}
|
|
||||||
if err := f.openInterface(false); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
a.extMonitor.start()
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// families returns the configured address families (v4, and v6 when present).
|
|
||||||
func (a *InterfaceAllower) families() []*family {
|
|
||||||
families := []*family{a.family4}
|
|
||||||
if a.family6 != nil {
|
|
||||||
families = append(families, a.family6)
|
|
||||||
}
|
|
||||||
return families
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconcileExternalChains re-applies the INPUT accepts to external chains. It
|
|
||||||
// implements externalChainReconciler for the monitor.
|
|
||||||
func (a *InterfaceAllower) reconcileExternalChains() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, f := range a.families() {
|
|
||||||
if err := f.acceptExternalChainsRules(false); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close stops the monitor and removes the accept rules.
|
|
||||||
func (a *InterfaceAllower) Close() error {
|
|
||||||
a.extMonitor.stop()
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, f := range a.families() {
|
|
||||||
if err := f.removeAcceptFilterRules(); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
|
||||||
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
|
||||||
set: set,
|
|
||||||
prefixes: prefixes,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.getIpSetExprs(ref, isSource)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
|
||||||
// overlapping prefixes will result in an error, so we need to merge them
|
|
||||||
prefixes := firewall.MergeIPRanges(input.prefixes)
|
|
||||||
|
|
||||||
nfset := &nftables.Set{
|
|
||||||
Name: setName,
|
|
||||||
Comment: input.set.Comment(),
|
|
||||||
Table: r.workTable,
|
|
||||||
// required for prefixes
|
|
||||||
Interval: true,
|
|
||||||
KeyType: r.af.setKeyType,
|
|
||||||
}
|
|
||||||
|
|
||||||
elements := r.convertPrefixesToSet(prefixes)
|
|
||||||
nElements := len(elements)
|
|
||||||
|
|
||||||
maxElements := maxPrefixesSet * 2
|
|
||||||
initialElements := elements[:min(maxElements, nElements)]
|
|
||||||
|
|
||||||
if err := r.conn.AddSet(nfset, initialElements); err != nil {
|
|
||||||
return nil, fmt.Errorf("error adding set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush error: %w", err)
|
|
||||||
}
|
|
||||||
log.Debugf("Created new ipset: %s with %d initial prefixes (total prefixes %d)", setName, len(initialElements)/2, len(prefixes))
|
|
||||||
|
|
||||||
// The set is committed now. If a later batch fails, destroy it: the
|
|
||||||
// refcounter records nothing on a create-callback error, so it would
|
|
||||||
// otherwise leak, and a partial source set fails-open for deny rules.
|
|
||||||
if err := r.addRemainingElements(nfset, elements, maxElements); err != nil {
|
|
||||||
if derr := r.deleteIpSet(setName, nfset); derr != nil {
|
|
||||||
log.Warnf("rollback ipset %s after add failure: %v", setName, derr)
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Created new ipset: %s with %d prefixes", setName, len(prefixes))
|
|
||||||
return nfset, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRemainingElements adds element batches beyond the initial one in
|
|
||||||
// maxElements-sized chunks, flushing each. Called after the set has been
|
|
||||||
// created with its first batch.
|
|
||||||
func (r *family) addRemainingElements(nfset *nftables.Set, elements []nftables.SetElement, maxElements int) error {
|
|
||||||
nElements := len(elements)
|
|
||||||
for subStart := maxElements; subStart < nElements; subStart += maxElements {
|
|
||||||
subEnd := min(subStart+maxElements, nElements)
|
|
||||||
subElement := elements[subStart:subEnd]
|
|
||||||
nSubPrefixes := len(subElement) / 2
|
|
||||||
log.Tracef("Adding new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
|
||||||
if err := r.conn.SetAddElements(nfset, subElement); err != nil {
|
|
||||||
return fmt.Errorf("error adding prefixes (%d) to set %s: %w", nSubPrefixes, nfset.Name, err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush error: %w", err)
|
|
||||||
}
|
|
||||||
log.Debugf("Added new prefixes (%d) in ipset: %s", nSubPrefixes, nfset.Name)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
|
||||||
var elements []nftables.SetElement
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
|
|
||||||
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
|
|
||||||
firstIP := prefix.Addr()
|
|
||||||
|
|
||||||
// For a /0 the last address is the broadcast and its Next() overflows
|
|
||||||
// to an invalid Addr with an empty key, so wrap to the zero address,
|
|
||||||
// which nftables reads as the open end of a full-range interval.
|
|
||||||
var lastKey []byte
|
|
||||||
if prefix.Bits() == 0 {
|
|
||||||
lastKey = make([]byte, r.af.addrLen)
|
|
||||||
} else {
|
|
||||||
lastKey = calculateLastIP(prefix).Next().AsSlice()
|
|
||||||
}
|
|
||||||
|
|
||||||
// the nft tool also adds a zero-address IntervalEnd element, see https://github.com/google/nftables/issues/247
|
|
||||||
// nftables.SetElement{Key: make([]byte, r.af.addrLen), IntervalEnd: true},
|
|
||||||
elements = append(elements,
|
|
||||||
nftables.SetElement{Key: firstIP.AsSlice()},
|
|
||||||
nftables.SetElement{Key: lastKey, IntervalEnd: true},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
return elements
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculateLastIP determines the last IP in a given prefix.
|
|
||||||
func calculateLastIP(prefix netip.Prefix) netip.Addr {
|
|
||||||
masked := prefix.Masked()
|
|
||||||
if masked.Addr().Is4() {
|
|
||||||
hostMask := ^uint32(0) >> masked.Bits()
|
|
||||||
lastIP := uint32FromNetipAddr(masked.Addr()) | hostMask
|
|
||||||
return netip.AddrFrom4(uint32ToBytes(lastIP))
|
|
||||||
}
|
|
||||||
|
|
||||||
// IPv6: set host bits to all 1s
|
|
||||||
b := masked.Addr().As16()
|
|
||||||
bits := masked.Bits()
|
|
||||||
for i := bits; i < 128; i++ {
|
|
||||||
b[i/8] |= 1 << (7 - i%8)
|
|
||||||
}
|
|
||||||
return netip.AddrFrom16(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Utility function to convert netip.Addr to uint32.
|
|
||||||
func uint32FromNetipAddr(addr netip.Addr) uint32 {
|
|
||||||
b := addr.As4()
|
|
||||||
return binary.BigEndian.Uint32(b[:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Utility function to convert uint32 to a netip-compatible byte slice.
|
|
||||||
func uint32ToBytes(ip uint32) [4]byte {
|
|
||||||
var b [4]byte
|
|
||||||
binary.BigEndian.PutUint32(b[:], ip)
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) deleteIpSet(setName string, nfset *nftables.Set) error {
|
|
||||||
r.conn.DelSet(nfset)
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Deleted unused ipset %s", setName)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|
||||||
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Overlapping prefixes (e.g. duplicate resolved addresses) make the
|
|
||||||
// interval set reject the batch, so merge them as createIpSet does.
|
|
||||||
prefixes = firewall.MergeIPRanges(prefixes)
|
|
||||||
elements := r.convertPrefixesToSet(prefixes)
|
|
||||||
|
|
||||||
// Add in batches sized like createIpSet so a large update does not
|
|
||||||
// exceed the netlink message size limit.
|
|
||||||
maxElements := maxPrefixesSet * 2
|
|
||||||
for start := 0; start < len(elements); start += maxElements {
|
|
||||||
end := min(start+maxElements, len(elements))
|
|
||||||
if err := r.conn.SetAddElements(nfset, elements[start:end]); err != nil {
|
|
||||||
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf(flushError, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("updated set %s with %d prefixes", set.HashedName(), len(prefixes))
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
|
||||||
// dst offset by default
|
|
||||||
offset := r.af.dstAddrOffset
|
|
||||||
if isSource {
|
|
||||||
// src offset
|
|
||||||
offset = r.af.srcAddrOffset
|
|
||||||
}
|
|
||||||
|
|
||||||
return []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: offset,
|
|
||||||
Len: r.af.addrLen,
|
|
||||||
},
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ref.Out.Name,
|
|
||||||
SetID: ref.Out.ID,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestConvertPrefixesToSetWildcard verifies that a /0 prefix produces a
|
|
||||||
// usable interval. The last address of a /0 is the broadcast, whose Next()
|
|
||||||
// overflows to an invalid Addr with an empty key; the IntervalEnd must wrap
|
|
||||||
// to the zero address instead so nftables sees a full-range interval.
|
|
||||||
func TestConvertPrefixesToSetWildcard(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
af addrFamily
|
|
||||||
prefix string
|
|
||||||
}{
|
|
||||||
{"IPv4 /0", afIPv4, "0.0.0.0/0"},
|
|
||||||
{"IPv6 /0", afIPv6, "::/0"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
r := &family{af: tt.af}
|
|
||||||
elements := r.convertPrefixesToSet([]netip.Prefix{netip.MustParsePrefix(tt.prefix)})
|
|
||||||
|
|
||||||
require.Len(t, elements, 2, "expected start and interval-end element")
|
|
||||||
assert.False(t, elements[0].IntervalEnd, "first element is the interval start")
|
|
||||||
assert.True(t, elements[1].IntervalEnd, "second element is the interval end")
|
|
||||||
assert.Len(t, elements[1].Key, int(tt.af.addrLen), "interval-end key must be a zero address, not empty")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
85
client/firewall/nftables/ipsetstore_linux.go
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
package nftables
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ipsetStore struct {
|
||||||
|
ipsetReference map[string]int
|
||||||
|
ipsets map[string]map[string]struct{} // ipsetName -> list of ips
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIpsetStore() *ipsetStore {
|
||||||
|
return &ipsetStore{
|
||||||
|
ipsetReference: make(map[string]int),
|
||||||
|
ipsets: make(map[string]map[string]struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) ips(ipsetName string) (map[string]struct{}, bool) {
|
||||||
|
r, ok := s.ipsets[ipsetName]
|
||||||
|
return r, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) newIpset(ipsetName string) map[string]struct{} {
|
||||||
|
s.ipsetReference[ipsetName] = 0
|
||||||
|
ipList := make(map[string]struct{})
|
||||||
|
s.ipsets[ipsetName] = ipList
|
||||||
|
return ipList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) deleteIpset(ipsetName string) {
|
||||||
|
delete(s.ipsetReference, ipsetName)
|
||||||
|
delete(s.ipsets, ipsetName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) DeleteIpFromSet(ipsetName string, ip net.IP) {
|
||||||
|
ipList, ok := s.ipsets[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(ipList, ip.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) AddIpToSet(ipsetName string, ip net.IP) {
|
||||||
|
ipList, ok := s.ipsets[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipList[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) IsIpInSet(ipsetName string, ip net.IP) bool {
|
||||||
|
ipList, ok := s.ipsets[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok = ipList[ip.String()]
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) AddReferenceToIpset(ipsetName string) {
|
||||||
|
s.ipsetReference[ipsetName]++
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) DeleteReferenceFromIpSet(ipsetName string) {
|
||||||
|
r, ok := s.ipsetReference[ipsetName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.ipsetReference[ipsetName]--
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ipsetStore) HasReferenceToSet(ipsetName string) bool {
|
||||||
|
if _, ok := s.ipsetReference[ipsetName]; !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if s.ipsetReference[ipsetName] == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -15,6 +16,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -43,17 +45,18 @@ type iFaceMapper interface {
|
|||||||
Address() wgaddr.Address
|
Address() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// Manager of nftables firewall. Per-family state (peer ACLs, route
|
// Manager of iptables firewall
|
||||||
// ACLs, NAT, DNAT, MSS clamping) lives on family; Manager dispatches
|
|
||||||
// by family and provides the public firewall.Manager surface.
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
rConn *nftables.Conn
|
rConn *nftables.Conn
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
family4 *family
|
router *router
|
||||||
// IPv6 counterpart, nil when no v6 overlay.
|
aclManager *AclManager
|
||||||
family6 *family
|
|
||||||
|
// IPv6 counterparts, nil when no v6 overlay
|
||||||
|
router6 *router
|
||||||
|
aclManager6 *AclManager
|
||||||
|
|
||||||
notrackOutputChain *nftables.Chain
|
notrackOutputChain *nftables.Chain
|
||||||
notrackPreroutingChain *nftables.Chain
|
notrackPreroutingChain *nftables.Chain
|
||||||
@@ -71,10 +74,21 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
tableName := getTableName()
|
tableName := getTableName()
|
||||||
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
workTable := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}
|
||||||
|
|
||||||
m.family4 = newFamily(workTable, wgIface, mtu)
|
var err error
|
||||||
|
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if wgIface.Address().HasIPv6() {
|
if wgIface.Address().HasIPv6() {
|
||||||
m.createIPv6Components(tableName, wgIface, mtu)
|
if err := m.createIPv6Components(tableName, wgIface, mtu); err != nil {
|
||||||
|
return nil, fmt.Errorf("create IPv6 firewall: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.extMonitor = newExternalChainMonitor(m)
|
m.extMonitor = newExternalChainMonitor(m)
|
||||||
@@ -82,19 +96,30 @@ func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
|||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) {
|
func (m *Manager) createIPv6Components(tableName string, wgIface iFaceMapper, mtu uint16) error {
|
||||||
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
workTable6 := &nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv6}
|
||||||
|
|
||||||
m.family6 = newFamily(workTable6, wgIface, mtu)
|
var err error
|
||||||
|
m.router6, err = newRouter(workTable6, wgIface, mtu)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 router: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Share the same IP forwarding state with the v4 router, since
|
// Share the same IP forwarding state with the v4 router, since
|
||||||
// EnableIPForwarding controls both v4 and v6 sysctls.
|
// EnableIPForwarding controls both v4 and v6 sysctls.
|
||||||
m.family6.ipFwdState = m.family4.ipFwdState
|
m.router6.ipFwdState = m.router.ipFwdState
|
||||||
|
|
||||||
|
m.aclManager6, err = newAclManager(workTable6, wgIface, chainNameRoutingFw)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create v6 acl manager: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
// hasIPv6 reports whether the manager has IPv6 components initialized.
|
||||||
func (m *Manager) hasIPv6() bool {
|
func (m *Manager) hasIPv6() bool {
|
||||||
return m.family6 != nil
|
return m.router6 != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) initIPv6() error {
|
func (m *Manager) initIPv6() error {
|
||||||
@@ -103,8 +128,12 @@ func (m *Manager) initIPv6() error {
|
|||||||
return fmt.Errorf("create v6 work table: %w", err)
|
return fmt.Errorf("create v6 work table: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.family6.init(workTable6); err != nil {
|
if err := m.router6.init(workTable6); err != nil {
|
||||||
return fmt.Errorf("v6 family init: %w", err)
|
return fmt.Errorf("v6 router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclManager6.init(workTable6); err != nil {
|
||||||
|
return fmt.Errorf("v6 acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -127,20 +156,19 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
// reconcileExternalChains re-applies passthrough accept rules to external
|
// reconcileExternalChains re-applies passthrough accept rules to external
|
||||||
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
|
// filter chains for both IPv4 and IPv6 routers. Called by the monitor when
|
||||||
// tables or chains appear (e.g. after firewalld reloads). Kernel routing opens
|
// tables or chains appear (e.g. after firewalld reloads).
|
||||||
// both INPUT and FORWARD.
|
|
||||||
func (m *Manager) reconcileExternalChains() error {
|
func (m *Manager) reconcileExternalChains() error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
if m.family4 != nil {
|
if m.router != nil {
|
||||||
if err := m.family4.acceptExternalChainsRules(true); err != nil {
|
if err := m.router.acceptExternalChainsRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("v4: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.family6.acceptExternalChainsRules(true); err != nil {
|
if err := m.router6.acceptExternalChainsRules(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("v6: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -159,8 +187,12 @@ func (m *Manager) initFirewall() (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := m.family4.init(workTable); err != nil {
|
if err := m.router.init(workTable); err != nil {
|
||||||
return fmt.Errorf("family init: %w", err)
|
return fmt.Errorf("router init: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.aclManager.init(workTable); err != nil {
|
||||||
|
return fmt.Errorf("acl manager init: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
@@ -188,7 +220,7 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
|||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
MTU: m.family4.mtu,
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
@@ -203,12 +235,12 @@ func (m *Manager) persistState(stateManager *statemanager.Manager) {
|
|||||||
|
|
||||||
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
// rollbackInit performs best-effort cleanup of already-initialized state when Init fails partway through.
|
||||||
func (m *Manager) rollbackInit() {
|
func (m *Manager) rollbackInit() {
|
||||||
if err := m.family4.Reset(); err != nil {
|
if err := m.router.Reset(); err != nil {
|
||||||
log.Warnf("rollback family: %v", err)
|
log.Warnf("rollback router: %v", err)
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.family6.Reset(); err != nil {
|
if err := m.router6.Reset(); err != nil {
|
||||||
log.Warnf("rollback v6 family: %v", err)
|
log.Warnf("rollback v6 router: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := m.cleanupNetbirdTables(); err != nil {
|
if err := m.cleanupNetbirdTables(); err != nil {
|
||||||
@@ -219,82 +251,118 @@ func (m *Manager) rollbackInit() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFilterRule installs a packet-filtering rule.
|
// AddPeerFiltering rule to the firewall
|
||||||
//
|
//
|
||||||
// Destination semantics: zero Network → input chain (peer ACL);
|
// If comment argument is empty firewall manager should set
|
||||||
// set Network → forward chain (route ACL).
|
// rule ID as comment for the rule
|
||||||
//
|
func (m *Manager) AddPeerFiltering(
|
||||||
// Sources are a single address family; the rule is dispatched to the
|
|
||||||
// matching per-family backend.
|
|
||||||
func (m *Manager) AddFilterRule(
|
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
ip net.IP,
|
||||||
destination firewall.Network,
|
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
ipsetName string,
|
||||||
if len(sources) == 0 {
|
) ([]firewall.Rule, error) {
|
||||||
return nil, firewall.ErrNoSources
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
fam := m.family4
|
if ip.To4() != nil {
|
||||||
if isIPv6Rule(sources, destination) {
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
if !m.hasIPv6() {
|
|
||||||
return nil, fmt.Errorf("add filtering: %w", firewall.ErrIPv6NotInitialized)
|
|
||||||
}
|
|
||||||
fam = m.family6
|
|
||||||
}
|
}
|
||||||
return fam.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("add peer filtering for %s: %w", ip, firewall.ErrIPv6NotInitialized)
|
||||||
|
}
|
||||||
|
return m.aclManager6.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteFilterRule removes a filtering rule. The owning family is found
|
func (m *Manager) AddRouteFiltering(
|
||||||
// by id in the in-memory filter maps, which are the only tracking for
|
id []byte,
|
||||||
// filter rules. family.DeleteFilterRule is idempotent when the id is
|
sources []netip.Prefix,
|
||||||
// absent.
|
destination firewall.Network,
|
||||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
fam, err := m.familyForRuleID(rule.ID(), (*family).hasRule, false)
|
if isIPv6RouteRule(sources, destination) {
|
||||||
|
if !m.hasIPv6() {
|
||||||
|
return nil, fmt.Errorf("add route filtering: %w", firewall.ErrIPv6NotInitialized)
|
||||||
|
}
|
||||||
|
return m.router6.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if m.hasIPv6() && isIPv6Rule(rule) {
|
||||||
|
return m.aclManager6.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
|
return m.aclManager.DeletePeerRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIPv6Rule(rule firewall.Rule) bool {
|
||||||
|
r, ok := rule.(*Rule)
|
||||||
|
return ok && r.nftRule != nil && r.nftRule.Table != nil && r.nftRule.Table.Family == nftables.TableFamilyIPv6
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPv6RouteRule determines whether a route rule belongs to the v6 table.
|
||||||
|
// For static routes, the destination prefix determines the family. For dynamic
|
||||||
|
// routes (DomainSet), the sources determine the family since management
|
||||||
|
// duplicates dynamic rules per family.
|
||||||
|
func isIPv6RouteRule(sources []netip.Prefix, destination firewall.Network) bool {
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
return destination.Prefix.Addr().Is6()
|
||||||
|
}
|
||||||
|
return len(sources) > 0 && sources[0].Addr().Is6()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule deletes a routing rule. Route rules live in exactly one
|
||||||
|
// router; the cached maps are normally authoritative, so the kernel is only
|
||||||
|
// consulted when neither map knows about the rule.
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
id := rule.ID()
|
||||||
|
r, err := m.routerForRuleID(id, (*router).hasRule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return fam.DeleteFilterRule(rule)
|
return r.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// familyForRuleID picks the family holding the rule with the given id, using
|
// routerForRuleID picks the router holding the rule with the given id, using
|
||||||
// the supplied lookup. With refresh set, a miss in both cached maps reloads
|
// the supplied lookup. If the cached maps disagree (or both miss), it refreshes
|
||||||
// the NAT/DNAT rule maps from the kernel once and re-checks before falling
|
// from the kernel once and re-checks before falling back to the v4 router.
|
||||||
// back to the v4 family. Filter rules are tracked only in memory and have no
|
func (m *Manager) routerForRuleID(id string, has func(*router, string) bool) (*router, error) {
|
||||||
// kernel-backed reload, so their callers pass refresh as false.
|
if has(m.router, id) {
|
||||||
func (m *Manager) familyForRuleID(id firewall.RuleID, has func(*family, firewall.RuleID) bool, refresh bool) (*family, error) {
|
return m.router, nil
|
||||||
if has(m.family4, id) {
|
}
|
||||||
return m.family4, nil
|
if m.hasIPv6() && has(m.router6, id) {
|
||||||
|
return m.router6, nil
|
||||||
}
|
}
|
||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return m.family4, nil
|
return m.router, nil
|
||||||
}
|
}
|
||||||
if has(m.family6, id) {
|
if err := m.router.refreshRulesMap(); err != nil {
|
||||||
return m.family6, nil
|
|
||||||
}
|
|
||||||
if !refresh {
|
|
||||||
return m.family4, nil
|
|
||||||
}
|
|
||||||
if err := m.family4.refreshRulesMap(); err != nil {
|
|
||||||
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
return nil, fmt.Errorf("refresh v4 rules: %w", err)
|
||||||
}
|
}
|
||||||
if err := m.family6.refreshRulesMap(); err != nil {
|
if err := m.router6.refreshRulesMap(); err != nil {
|
||||||
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
return nil, fmt.Errorf("refresh v6 rules: %w", err)
|
||||||
}
|
}
|
||||||
if has(m.family6, id) && !has(m.family4, id) {
|
if has(m.router6, id) && !has(m.router, id) {
|
||||||
return m.family6, nil
|
return m.router6, nil
|
||||||
}
|
}
|
||||||
return m.family4, nil
|
return m.router, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) IsServerRouteSupported() bool {
|
func (m *Manager) IsServerRouteSupported() bool {
|
||||||
@@ -313,10 +381,10 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add NAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddNatRule(pair)
|
return m.router6.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.family4.AddNatRule(pair); err != nil {
|
if err := m.router.AddNatRule(pair); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,7 +396,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
// so the eventual cleanup still works.
|
// so the eventual cleanup still works.
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.family6.AddNatRule(v6Pair); err != nil {
|
if err := m.router6.AddNatRule(v6Pair); err != nil {
|
||||||
return fmt.Errorf("add v6 NAT rule: %w", err)
|
return fmt.Errorf("add v6 NAT rule: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -344,18 +412,18 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return m.family6.RemoveNatRule(pair)
|
return m.router6.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.family4.RemoveNatRule(pair); err != nil {
|
if err := m.router.RemoveNatRule(pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v4 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && pair.Dynamic {
|
if m.hasIPv6() && pair.Dynamic {
|
||||||
v6Pair := firewall.ToV6NatPair(pair)
|
v6Pair := firewall.ToV6NatPair(pair)
|
||||||
if err := m.family6.RemoveNatRule(v6Pair); err != nil {
|
if err := m.router6.RemoveNatRule(v6Pair); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove v6 NAT rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -363,13 +431,46 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic.
|
||||||
|
// This is called when USPFilter wraps the native firewall, adding blanket accept
|
||||||
|
// rules so that packet filtering is handled in userspace instead of by netfilter.
|
||||||
|
//
|
||||||
|
// TODO: In USP mode this only adds ACCEPT to the netbird table's own chains,
|
||||||
|
// which doesn't override DROP rules in external tables (e.g. firewalld).
|
||||||
|
// Should add passthrough rules to external chains (like the native mode router's
|
||||||
|
// addExternalChainsRules does) for both the netbird table family and inet tables.
|
||||||
|
// The netbird table itself is fine (routing chains already exist there), but
|
||||||
|
// non-netbird tables with INPUT/FORWARD hooks can still DROP our WG traffic.
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||||
|
return fmt.Errorf("create default allow rules: %w", err)
|
||||||
|
}
|
||||||
|
if m.hasIPv6() {
|
||||||
|
if err := m.aclManager6.createDefaultAllowRules(); err != nil {
|
||||||
|
return fmt.Errorf("create v6 default allow rules: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management
|
// SetLegacyManagement sets the route manager to use legacy management
|
||||||
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
if err := firewall.SetLegacyManagement(m.family4, isLegacy); err != nil {
|
if err := firewall.SetLegacyManagement(m.router, isLegacy); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
return firewall.SetLegacyManagement(m.family6, isLegacy)
|
return firewall.SetLegacyManagement(m.router6, isLegacy)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -383,13 +484,13 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
if err := m.family4.Reset(); err != nil {
|
if err := m.router.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset family: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset router: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.family6.Reset(); err != nil {
|
if err := m.router6.Reset(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("reset v6 family: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("reset v6 router: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -429,14 +530,14 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
if err := m.family4.ipFwdState.RequestForwarding(); err != nil {
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
return fmt.Errorf("enable IP forwarding: %w", err)
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
if err := m.family4.ipFwdState.ReleaseForwarding(); err != nil {
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
return fmt.Errorf("disable IP forwarding: %w", err)
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -450,13 +551,13 @@ func (m *Manager) Flush() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if err := m.family4.Flush(); err != nil {
|
if err := m.aclManager.Flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() {
|
if m.hasIPv6() {
|
||||||
if err := m.family6.Flush(); err != nil {
|
if err := m.aclManager6.Flush(); err != nil {
|
||||||
return fmt.Errorf("flush v6 family: %w", err)
|
return fmt.Errorf("flush v6 acl: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -476,9 +577,9 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error)
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
return nil, fmt.Errorf("add DNAT rule: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddDNATRule(rule)
|
return m.router6.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
return m.family4.AddDNATRule(rule)
|
return m.router.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
@@ -486,7 +587,7 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
r, err := m.familyForRuleID(rule.ID(), (*family).hasDNATRule, true)
|
r, err := m.routerForRuleID(rule.ID(), (*router).hasDNATRule)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -507,12 +608,12 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.family4.UpdateSet(set, v4Prefixes); err != nil {
|
if err := m.router.UpdateSet(set, v4Prefixes); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
if m.hasIPv6() && len(v6Prefixes) > 0 {
|
||||||
if err := m.family6.UpdateSet(set, v6Prefixes); err != nil {
|
if err := m.router6.UpdateSet(set, v6Prefixes); err != nil {
|
||||||
return fmt.Errorf("update v6 set: %w", err)
|
return fmt.Errorf("update v6 set: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -529,9 +630,9 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.AddInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveInboundDNAT removes an inbound DNAT rule.
|
// RemoveInboundDNAT removes an inbound DNAT rule.
|
||||||
@@ -543,9 +644,9 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove inbound DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.RemoveInboundDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
// AddOutputDNAT adds an OUTPUT chain DNAT rule for locally-generated traffic.
|
||||||
@@ -557,9 +658,9 @@ func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("add output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
// RemoveOutputDNAT removes an OUTPUT chain DNAT rule.
|
||||||
@@ -571,9 +672,9 @@ func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
if !m.hasIPv6() {
|
if !m.hasIPv6() {
|
||||||
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
return fmt.Errorf("remove output DNAT: %w", firewall.ErrIPv6NotInitialized)
|
||||||
}
|
}
|
||||||
return m.family6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router6.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
return m.family4.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
return m.router.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -802,14 +903,3 @@ func getEstablishedExprs(register uint32) []expr.Any {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isIPv6Rule reports whether the rule belongs to the v6 table. For a
|
|
||||||
// prefix destination the destination family decides; otherwise the
|
|
||||||
// (single-family) sources do, since management duplicates rules per
|
|
||||||
// family.
|
|
||||||
func isIPv6Rule(sources []netip.Prefix, destination firewall.Network) bool {
|
|
||||||
if destination.IsPrefix() {
|
|
||||||
return destination.Prefix.Addr().Is6()
|
|
||||||
}
|
|
||||||
return len(sources) > 0 && sources[0].Addr().Is6()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
//go:build integration && !android
|
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -72,13 +70,13 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop)
|
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
require.Len(t, rules, 2, "expected 2 rules")
|
require.Len(t, rules, 2, "expected 2 rules")
|
||||||
@@ -149,12 +147,15 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
|
// Compare connection tracking rule at position 1 (pushed down by DROP rule insertion)
|
||||||
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[1].Exprs, expectedExprs1)
|
||||||
|
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
|
for _, r := range rule {
|
||||||
|
err = manager.DeletePeerRule(r)
|
||||||
|
require.NoError(t, err, "failed to delete rule")
|
||||||
|
}
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
// established rule remains
|
// established rule remains
|
||||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
@@ -179,39 +180,47 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
|||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
// Add accept rule first
|
// Add accept rule first
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "accept-http")
|
||||||
require.NoError(t, err, "failed to add accept rule")
|
require.NoError(t, err, "failed to add accept rule")
|
||||||
|
|
||||||
// Add deny rule second for the same traffic
|
// Add deny rule second for the same traffic
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop, "deny-http")
|
||||||
require.NoError(t, err, "failed to add deny rule")
|
require.NoError(t, err, "failed to add deny rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
require.NoError(t, err, "failed to flush")
|
require.NoError(t, err, "failed to flush")
|
||||||
|
|
||||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
|
||||||
require.NoError(t, err, "failed to get rules")
|
require.NoError(t, err, "failed to get rules")
|
||||||
|
|
||||||
t.Logf("Found %d rules in nftables chain", len(rules))
|
t.Logf("Found %d rules in nftables chain", len(rules))
|
||||||
|
|
||||||
// Single-source rules emit a direct payload+cmp on the source IP
|
// Find the accept and deny rules and verify deny comes before accept
|
||||||
// (no set lookup). Match by source-IP + port + verdict instead of
|
|
||||||
// the legacy per-(action,port) set names ("deny-http"/"accept-http")
|
|
||||||
// that this test predates.
|
|
||||||
wantSrc := ip.AsSlice()
|
|
||||||
var acceptRuleIndex, denyRuleIndex = -1, -1
|
var acceptRuleIndex, denyRuleIndex = -1, -1
|
||||||
for i, rule := range rules {
|
for i, rule := range rules {
|
||||||
var hasSrc, hasPort80 bool
|
hasAcceptHTTPSet := false
|
||||||
|
hasDenyHTTPSet := false
|
||||||
|
hasPort80 := false
|
||||||
var action string
|
var action string
|
||||||
|
|
||||||
for _, e := range rule.Exprs {
|
for _, e := range rule.Exprs {
|
||||||
if cmp, ok := e.(*expr.Cmp); ok && cmp.Op == expr.CmpOpEq {
|
// Check for set lookup
|
||||||
if bytes.Equal(cmp.Data, wantSrc) {
|
if lookup, ok := e.(*expr.Lookup); ok {
|
||||||
hasSrc = true
|
switch lookup.SetName {
|
||||||
|
case "accept-http":
|
||||||
|
hasAcceptHTTPSet = true
|
||||||
|
case "deny-http":
|
||||||
|
hasDenyHTTPSet = true
|
||||||
}
|
}
|
||||||
if len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
|
||||||
|
}
|
||||||
|
// Check for port 80
|
||||||
|
if cmp, ok := e.(*expr.Cmp); ok {
|
||||||
|
if cmp.Op == expr.CmpOpEq && len(cmp.Data) == 2 && binary.BigEndian.Uint16(cmp.Data) == 80 {
|
||||||
hasPort80 = true
|
hasPort80 = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Check for verdict
|
||||||
if verdict, ok := e.(*expr.Verdict); ok {
|
if verdict, ok := e.(*expr.Verdict); ok {
|
||||||
switch verdict.Kind {
|
switch verdict.Kind {
|
||||||
case expr.VerdictAccept:
|
case expr.VerdictAccept:
|
||||||
@@ -222,15 +231,11 @@ func TestNftablesManagerRuleOrder(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasSrc || !hasPort80 {
|
if hasAcceptHTTPSet && hasPort80 && action == "ACCEPT" {
|
||||||
continue
|
t.Logf("Rule [%d]: accept-http set + Port 80 + ACCEPT", i)
|
||||||
}
|
|
||||||
switch action {
|
|
||||||
case "ACCEPT":
|
|
||||||
t.Logf("Rule [%d]: src=%s port=80 ACCEPT", i, ip)
|
|
||||||
acceptRuleIndex = i
|
acceptRuleIndex = i
|
||||||
case "DROP":
|
} else if hasDenyHTTPSet && hasPort80 && action == "DROP" {
|
||||||
t.Logf("Rule [%d]: src=%s port=80 DROP", i, ip)
|
t.Logf("Rule [%d]: deny-http set + Port 80 + DROP", i)
|
||||||
denyRuleIndex = i
|
denyRuleIndex = i
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -274,7 +279,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -356,10 +361,10 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := netip.MustParseAddr("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddFilterRule(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||||
@@ -432,10 +437,10 @@ func TestNftablesManagerIPv6CompatibilityWithIp6tables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := netip.MustParseAddr("fd00::2")
|
ip := netip.MustParseAddr("fd00::2")
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "add v6 peer filtering rule")
|
require.NoError(t, err, "add v6 peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddFilterRule(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
[]netip.Prefix{netip.MustParsePrefix("fd00:1::/64")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
fw.Network{Prefix: netip.MustParsePrefix("2001:db8::/48")},
|
||||||
@@ -545,7 +550,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
|||||||
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
prefixes = append(prefixes, netip.PrefixFrom(addr, 24))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_, err = manager.AddFilterRule(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
prefixes,
|
prefixes,
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||||
@@ -560,7 +565,7 @@ func TestNftablesManagerCompatibilityWithIptablesFor6kPrefixes(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T) {
|
func TestNftablesManagerCompatibilityWithIptablesForEmptyPrefixes(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
}
|
}
|
||||||
@@ -586,9 +591,9 @@ func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T)
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
_, err = manager.AddFilterRule(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("10.2.0.0/24")},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
@@ -601,73 +606,6 @@ func TestNftablesManagerCompatibilityWithIptablesForWildcardSource(t *testing.T)
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNftablesManagerMultiPortFilter(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.NoError(t, manager.Init(nil))
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil), "failed to reset manager state")
|
|
||||||
})
|
|
||||||
|
|
||||||
ip := netip.MustParseAddr("100.96.0.1")
|
|
||||||
|
|
||||||
rule, err := manager.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80, 443}}, fw.ActionAccept)
|
|
||||||
require.NoError(t, err, "failed to add multi-port rule")
|
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
|
||||||
rules, err := testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
|
||||||
require.NoError(t, err, "failed to get rules")
|
|
||||||
|
|
||||||
var lookup *expr.Lookup
|
|
||||||
for _, kernelRule := range rules {
|
|
||||||
if string(kernelRule.UserData) != string(rule.ID()) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, e := range kernelRule.Exprs {
|
|
||||||
if l, ok := e.(*expr.Lookup); ok {
|
|
||||||
lookup = l
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.NotNil(t, lookup, "multi-port rule must match ports via a set lookup")
|
|
||||||
|
|
||||||
sets, err := testClient.GetSets(manager.family4.workTable)
|
|
||||||
require.NoError(t, err, "failed to get sets")
|
|
||||||
|
|
||||||
var portSet *nftables.Set
|
|
||||||
for _, s := range sets {
|
|
||||||
if s.Name == lookup.SetName {
|
|
||||||
portSet = s
|
|
||||||
}
|
|
||||||
}
|
|
||||||
require.NotNil(t, portSet, "anonymous port set not found in kernel")
|
|
||||||
|
|
||||||
portSet.Table = manager.family4.workTable
|
|
||||||
elements, err := testClient.GetSetElements(portSet)
|
|
||||||
require.NoError(t, err, "failed to get set elements")
|
|
||||||
|
|
||||||
ports := make(map[uint16]bool)
|
|
||||||
for _, e := range elements {
|
|
||||||
require.Len(t, e.Key, 2, "port set element key should be 2 bytes")
|
|
||||||
ports[binary.BigEndian.Uint16(e.Key)] = true
|
|
||||||
}
|
|
||||||
require.True(t, ports[80], "port set should contain port 80")
|
|
||||||
require.True(t, ports[443], "port set should contain port 443")
|
|
||||||
|
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule), "failed to delete rule")
|
|
||||||
|
|
||||||
rules, err = testClient.GetRules(manager.family4.workTable, manager.family4.chainInputRules)
|
|
||||||
require.NoError(t, err, "failed to get rules after delete")
|
|
||||||
for _, kernelRule := range rules {
|
|
||||||
require.NotEqual(t, string(rule.ID()), string(kernelRule.UserData), "rule should be removed from kernel")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
require.Equal(t, len(got), len(want), "expression count mismatch")
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|||||||
2244
client/firewall/nftables/router_linux.go
Normal file
2244
client/firewall/nftables/router_linux.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
|||||||
//go:build integration && !android
|
//go:build !android
|
||||||
|
|
||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
// need fw manager to init both acl mgr and family for all chains to be present
|
// need fw manager to init both acl mgr and router for all chains to be present
|
||||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -47,7 +47,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
|
|
||||||
nftablesTestingClient := &nftables.Conn{}
|
nftablesTestingClient := &nftables.Conn{}
|
||||||
|
|
||||||
rtr := manager.family4
|
rtr := manager.router
|
||||||
err = rtr.AddNatRule(testCase.InputPair)
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "pair should be inserted")
|
require.NoError(t, err, "pair should be inserted")
|
||||||
|
|
||||||
@@ -90,9 +90,9 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
testRouter := &family{af: afIPv4}
|
testRouter := &router{af: afIPv4}
|
||||||
sourceExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Source.Prefix, true)
|
sourceExp := testRouter.applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||||
destExp := prefixMatchExprs(testRouter.af, testCase.InputPair.Destination.Prefix, false)
|
destExp := testRouter.applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -100,14 +100,14 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
testingExpression = append(testingExpression, sourceExp...)
|
testingExpression = append(testingExpression, sourceExp...)
|
||||||
testingExpression = append(testingExpression, destExp...)
|
testingExpression = append(testingExpression, destExp...)
|
||||||
|
|
||||||
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range rtr.chains {
|
for _, chain := range rtr.chains {
|
||||||
if chain.Name == chainNameManglePrerouting {
|
if chain.Name == chainNameManglePrerouting {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
// Compare expressions up to the mark setting expressions
|
// Compare expressions up to the mark setting expressions
|
||||||
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
|
||||||
found = 1
|
found = 1
|
||||||
@@ -135,19 +135,19 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
rtr := manager.family4
|
rtr := manager.router
|
||||||
|
|
||||||
// First add the NAT rule using the family's method
|
// First add the NAT rule using the router's method
|
||||||
err = rtr.AddNatRule(testCase.InputPair)
|
err = rtr.AddNatRule(testCase.InputPair)
|
||||||
require.NoError(t, err, "should add NAT rule")
|
require.NoError(t, err, "should add NAT rule")
|
||||||
|
|
||||||
// Verify the rule was added
|
// Verify the rule was added
|
||||||
natRuleKey := testCase.InputPair.GenKey(firewall.PreroutingFormat)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := false
|
found := false
|
||||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules")
|
require.NoError(t, err, "should list rules")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -163,7 +163,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules after removal")
|
require.NoError(t, err, "should list rules after removal")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
found = true
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -200,10 +200,11 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "Failed to create router")
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func(r *family) {
|
defer func(r *router) {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset rules")
|
require.NoError(t, r.Reset(), "Failed to reset rules")
|
||||||
}(r)
|
}(r)
|
||||||
|
|
||||||
@@ -313,16 +314,16 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddFilterRule(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddFilterRule failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, r.DeleteFilterRule(ruleKey), "Failed to delete rule")
|
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
|
||||||
})
|
})
|
||||||
|
|
||||||
stored, ok := r.filters[id.RuleID(ruleKey.ID())]
|
// Check if the rule is in the internal map
|
||||||
require.True(t, ok, "Rule not found in filters map")
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
rule := stored.nftRule
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
t.Log("Internal rule expressions:")
|
t.Log("Internal rule expressions:")
|
||||||
for i, expr := range rule.Exprs {
|
for i, expr := range rule.Exprs {
|
||||||
@@ -338,7 +339,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
var nftRule *nftables.Rule
|
var nftRule *nftables.Rule
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if firewall.RuleID(rule.UserData) == ruleKey.ID() {
|
if string(rule.UserData) == ruleKey.ID() {
|
||||||
nftRule = rule
|
nftRule = rule
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -366,11 +367,12 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "Failed to create router")
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -507,41 +509,6 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestNftablesUpdateSetMergesOverlapping verifies that UpdateSet merges
|
|
||||||
// overlapping prefixes before adding them. An interval set rejects
|
|
||||||
// overlapping elements, so without the merge a batch holding a /32 already
|
|
||||||
// covered by a /24, or a duplicate address as DNS resolution can produce,
|
|
||||||
// would fail.
|
|
||||||
func TestNftablesUpdateSetMergesOverlapping(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
workTable, err := createWorkTable()
|
|
||||||
require.NoError(t, err, "create work table")
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, r.init(workTable))
|
|
||||||
defer func() {
|
|
||||||
require.NoError(t, r.Reset(), "reset family")
|
|
||||||
}()
|
|
||||||
|
|
||||||
initial := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
|
||||||
set := firewall.NewPrefixSet(initial)
|
|
||||||
|
|
||||||
created, err := r.createIpSet(set.HashedName(), setInput{prefixes: initial})
|
|
||||||
require.NoError(t, err, "create ip set")
|
|
||||||
require.NotNil(t, created)
|
|
||||||
|
|
||||||
overlapping := []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
netip.MustParsePrefix("192.168.1.1/32"),
|
|
||||||
netip.MustParsePrefix("192.168.1.1/32"),
|
|
||||||
}
|
|
||||||
require.NoError(t, r.UpdateSet(set, overlapping), "UpdateSet must merge overlapping prefixes")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
@@ -551,10 +518,11 @@ func TestNftablesCreateIpSet_IPv6(t *testing.T) {
|
|||||||
require.NoError(t, err, "Failed to create v6 work table")
|
require.NoError(t, err, "Failed to create v6 work table")
|
||||||
defer deleteWorkTableIPv6()
|
defer deleteWorkTableIPv6()
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err, "Failed to create router")
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, r.Reset(), "Failed to reset family")
|
require.NoError(t, r.Reset(), "Failed to reset router")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -780,14 +748,6 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case *expr.Lookup:
|
|
||||||
// Multiple discrete ports compile to an anonymous set lookup
|
|
||||||
// rather than a chain of comparisons. The set's id and name are
|
|
||||||
// assigned dynamically, so matching the lookup is enough here;
|
|
||||||
// the set elements are verified separately.
|
|
||||||
if !port.IsRange && len(port.Values) > 1 {
|
|
||||||
portMatchFound = true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if payloadFound && portMatchFound {
|
if payloadFound && portMatchFound {
|
||||||
return true
|
return true
|
||||||
@@ -901,12 +861,13 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
// Add a real rule to the kernel
|
// Add a real rule to the kernel
|
||||||
ruleKey, err := r.AddFilterRule(
|
ruleKey, err := r.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
@@ -917,11 +878,11 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||||
staleKey := firewall.RuleID("stale-rule-that-does-not-exist")
|
staleKey := "stale-rule-that-does-not-exist"
|
||||||
r.rules[staleKey] = &nftables.Rule{
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
@@ -941,54 +902,6 @@ func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
|||||||
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestRouter_DeleteRouteRule_RemovesKernelRule verifies a route filter
|
|
||||||
// rule is actually removed from the kernel on delete. The route chain is
|
|
||||||
// not refreshed by Flush, so the stored rule carries a zero handle;
|
|
||||||
// DeleteFilterRule must pull live handles itself before issuing the
|
|
||||||
// delete or the kernel rule leaks. Regression test for that path.
|
|
||||||
func TestRouter_DeleteRouteRule_RemovesKernelRule(t *testing.T) {
|
|
||||||
if check() != NFTABLES {
|
|
||||||
t.Skip("nftables not supported on this system")
|
|
||||||
}
|
|
||||||
|
|
||||||
workTable, err := createWorkTable()
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer deleteWorkTable()
|
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
|
||||||
require.NoError(t, r.init(workTable))
|
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
|
||||||
|
|
||||||
ruleKey, err := r.AddFilterRule(
|
|
||||||
nil,
|
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
|
||||||
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
|
||||||
firewall.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&firewall.Port{Values: []uint16{80}},
|
|
||||||
firewall.ActionAccept,
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
countKernelRules := func() int {
|
|
||||||
list, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
|
|
||||||
require.NoError(t, err)
|
|
||||||
n := 0
|
|
||||||
for _, rule := range list {
|
|
||||||
if string(rule.UserData) == string(ruleKey.ID()) {
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, 1, countKernelRules(), "rule should be present in the kernel after add")
|
|
||||||
|
|
||||||
require.NoError(t, r.DeleteFilterRule(ruleKey))
|
|
||||||
assert.Equal(t, 0, countKernelRules(), "rule must be removed from the kernel after delete")
|
|
||||||
assert.NotContains(t, r.filters, ruleKey.ID(), "filters map entry should be cleared")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||||
if check() != NFTABLES {
|
if check() != NFTABLES {
|
||||||
t.Skip("nftables not supported on this system")
|
t.Skip("nftables not supported on this system")
|
||||||
@@ -998,27 +911,24 @@ func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r := newFamily(workTable, ifaceMock, iface.DefaultMTU)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
defer func() { require.NoError(t, r.Reset()) }()
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
// Inject a stale entry with Handle=0
|
// Inject a stale entry with Handle=0
|
||||||
staleKey := id.RuleID("stale-route-rule")
|
staleKey := "stale-route-rule"
|
||||||
staleRule := &Rule{
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
nftRule: &nftables.Rule{
|
Table: r.workTable,
|
||||||
Table: r.workTable,
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Handle: 0,
|
||||||
Handle: 0,
|
UserData: []byte(staleKey),
|
||||||
UserData: []byte(staleKey),
|
|
||||||
},
|
|
||||||
id: staleKey,
|
|
||||||
}
|
}
|
||||||
r.filters[staleKey] = staleRule
|
|
||||||
|
|
||||||
// DeleteFilterRule should not return an error for stale handles
|
// DeleteRouteRule should not return an error for stale handles
|
||||||
err = r.DeleteFilterRule(staleRule)
|
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||||
assert.NoError(t, err, "deleting a stale rule should not error")
|
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||||
assert.NotContains(t, r.filters, staleKey, "stale entry should be cleaned up")
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||||
@@ -1040,7 +950,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
rtr := manager.family4
|
rtr := manager.router
|
||||||
|
|
||||||
// First add succeeds
|
// First add succeeds
|
||||||
err = rtr.AddNatRule(pair)
|
err = rtr.AddNatRule(pair)
|
||||||
@@ -1050,11 +960,11 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Corrupt the handle to simulate stale state
|
// Corrupt the handle to simulate stale state
|
||||||
natRuleKey := pair.GenKey(firewall.PreroutingFormat)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
if rule, exists := rtr.rules[natRuleKey]; exists {
|
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||||
rule.Handle = 0
|
rule.Handle = 0
|
||||||
}
|
}
|
||||||
inverseKey := firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat)
|
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||||
if rule, exists := rtr.rules[inverseKey]; exists {
|
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||||
rule.Handle = 0
|
rule.Handle = 0
|
||||||
}
|
}
|
||||||
@@ -1069,7 +979,7 @@ func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
|||||||
|
|
||||||
found := 0
|
found := 0
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && firewall.RuleID(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
found++
|
found++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1100,7 +1010,7 @@ func TestCalculateLastIP(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
func TestConvertPrefixesToSet_IPv6(t *testing.T) {
|
||||||
r := &family{af: afIPv6}
|
r := &router{af: afIPv6}
|
||||||
prefixes := []netip.Prefix{
|
prefixes := []netip.Prefix{
|
||||||
netip.MustParsePrefix("fd00::/64"),
|
netip.MustParsePrefix("fd00::/64"),
|
||||||
netip.MustParsePrefix("2001:db8::1/128"),
|
netip.MustParsePrefix("2001:db8::1/128"),
|
||||||
|
|||||||
@@ -1,500 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/google/nftables"
|
|
||||||
"github.com/google/nftables/binaryutil"
|
|
||||||
"github.com/google/nftables/expr"
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *family) AddNatRule(pair firewall.RouterPair) error {
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.legacyManagement {
|
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
|
||||||
r.rollbackRules(pair)
|
|
||||||
return fmt.Errorf("add legacy routing rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
|
||||||
if err := r.addNatRule(pair); err != nil {
|
|
||||||
r.rollbackRules(pair)
|
|
||||||
return fmt.Errorf("add nat rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
|
|
||||||
r.rollbackRules(pair)
|
|
||||||
return fmt.Errorf("add inverse nat rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
r.rollbackRules(pair)
|
|
||||||
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
|
||||||
func (r *family) rollbackRules(pair firewall.RouterPair) {
|
|
||||||
keys := []firewall.RuleID{
|
|
||||||
pair.GenKey(firewall.ForwardingFormat),
|
|
||||||
pair.GenKey(firewall.PreroutingFormat),
|
|
||||||
firewall.GetInversePair(pair).GenKey(firewall.PreroutingFormat),
|
|
||||||
}
|
|
||||||
for _, key := range keys {
|
|
||||||
rule, ok := r.rules[key]
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
log.Warnf("rollback set counter for %s: %v", key, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
|
||||||
func (r *family) addNatRule(pair firewall.RouterPair) error {
|
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("apply source: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
|
||||||
if err != nil {
|
|
||||||
r.dropNetworkMatch(sourceExp)
|
|
||||||
return fmt.Errorf("apply destination: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
op := expr.CmpOpEq
|
|
||||||
if pair.Inverse {
|
|
||||||
op = expr.CmpOpNeq
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs := []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: op,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
|
||||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
|
||||||
exprs = append(exprs, getCtNewExprs()...)
|
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
|
||||||
exprs = append(exprs, destExp...)
|
|
||||||
|
|
||||||
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
|
|
||||||
if pair.Inverse {
|
|
||||||
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
|
|
||||||
}
|
|
||||||
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(markValue),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
SourceRegister: true,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
|
||||||
r.dropNetworkMatch(sourceExp)
|
|
||||||
r.dropNetworkMatch(destExp)
|
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure nat rules come first, so the mark can be overwritten.
|
|
||||||
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
|
||||||
r.rules[ruleID] = r.conn.InsertRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameManglePrerouting],
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(ruleID),
|
|
||||||
})
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addPostroutingRules() {
|
|
||||||
// First masquerade rule for traffic coming in from WireGuard interface
|
|
||||||
exprs := []expr.Any{
|
|
||||||
// Match on the first fwmark
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
|
|
||||||
},
|
|
||||||
|
|
||||||
// We need to exclude the loopback interface as this changes the ebpf proxy port
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyOIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname("lo"),
|
|
||||||
},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Masq{},
|
|
||||||
}
|
|
||||||
|
|
||||||
r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingNat],
|
|
||||||
Exprs: exprs,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Second masquerade rule for traffic going out through WireGuard interface
|
|
||||||
exprs2 := []expr.Any{
|
|
||||||
// Match on the second fwmark
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
|
|
||||||
},
|
|
||||||
|
|
||||||
// Match WireGuard interface
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyOIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Masq{},
|
|
||||||
}
|
|
||||||
|
|
||||||
r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingNat],
|
|
||||||
Exprs: exprs2,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
|
||||||
func (r *family) addMSSClampingRules() error {
|
|
||||||
overhead := uint16(ipv4TCPHeaderSize)
|
|
||||||
if r.af.tableFamily == nftables.TableFamilyIPv6 {
|
|
||||||
overhead = ipv6TCPHeaderSize
|
|
||||||
}
|
|
||||||
if r.mtu <= overhead {
|
|
||||||
log.Debugf("MTU %d too small for MSS clamping (overhead %d), skipping", r.mtu, overhead)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
mss := r.mtu - overhead
|
|
||||||
|
|
||||||
exprsOut := []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyOIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(r.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyL4PROTO,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{unix.IPPROTO_TCP},
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 13,
|
|
||||||
Len: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
DestRegister: 1,
|
|
||||||
SourceRegister: 1,
|
|
||||||
Len: 1,
|
|
||||||
Mask: []byte{0x02},
|
|
||||||
Xor: []byte{0x00},
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0x00},
|
|
||||||
},
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Exthdr{
|
|
||||||
DestRegister: 1,
|
|
||||||
Type: 2,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
Op: expr.ExthdrOpTcpopt,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpGt,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
|
||||||
},
|
|
||||||
&expr.Exthdr{
|
|
||||||
SourceRegister: 1,
|
|
||||||
Type: 2,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
Op: expr.ExthdrOpTcpopt,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameMangleForward],
|
|
||||||
Exprs: exprsOut,
|
|
||||||
})
|
|
||||||
|
|
||||||
return r.conn.Flush()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("apply source: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
|
||||||
if err != nil {
|
|
||||||
r.dropNetworkMatch(sourceExp)
|
|
||||||
return fmt.Errorf("apply destination: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var exprs []expr.Any
|
|
||||||
exprs = append(exprs, sourceExp...)
|
|
||||||
exprs = append(exprs, destExp...)
|
|
||||||
exprs = append(exprs,
|
|
||||||
&expr.Counter{},
|
|
||||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
|
||||||
)
|
|
||||||
|
|
||||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
|
||||||
|
|
||||||
if _, exists := r.rules[ruleID]; exists {
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
|
||||||
r.dropNetworkMatch(sourceExp)
|
|
||||||
r.dropNetworkMatch(destExp)
|
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
r.rules[ruleID] = r.conn.AddRule(&nftables.Rule{
|
|
||||||
Table: r.workTable,
|
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
|
||||||
Exprs: exprs,
|
|
||||||
UserData: []byte(ruleID),
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
|
|
||||||
func (r *family) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|
||||||
ruleID := pair.GenKey(firewall.ForwardingFormat)
|
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
|
||||||
if !exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.deleteLegacyRuleEntry(ruleID, rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteLegacyRuleEntry removes one legacy forwarding rule and drops its
|
|
||||||
// ipset references. It also clears stale entries that never got a handle.
|
|
||||||
func (r *family) deleteLegacyRuleEntry(ruleID firewall.RuleID, rule *nftables.Rule) error {
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleID)
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s: %w", ruleID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetLegacyManagement returns the route manager's legacy management mode
|
|
||||||
func (r *family) GetLegacyManagement() bool {
|
|
||||||
return r.legacyManagement
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLegacyManagement sets the route manager to use legacy management mode
|
|
||||||
func (r *family) SetLegacyManagement(isLegacy bool) {
|
|
||||||
r.legacyManagement = isLegacy
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
|
|
||||||
func (r *family) RemoveAllLegacyRouteRules() error {
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
for k, rule := range r.rules {
|
|
||||||
if !strings.HasPrefix(string(k), firewall.ForwardingFormatPrefix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.deleteLegacyRuleEntry(k, rule); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeNatPreroutingRules() error {
|
|
||||||
table := &nftables.Table{
|
|
||||||
Name: tableNat,
|
|
||||||
Family: r.af.tableFamily,
|
|
||||||
}
|
|
||||||
chain := &nftables.Chain{
|
|
||||||
Name: chainNameNatPrerouting,
|
|
||||||
Table: table,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
|
||||||
Priority: nftables.ChainPriorityNATDest,
|
|
||||||
Type: nftables.ChainTypeNAT,
|
|
||||||
}
|
|
||||||
rules, err := r.conn.GetRules(table, chain)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get rules from nat table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
// Delete rules that have our UserData suffix
|
|
||||||
for _, rule := range rules {
|
|
||||||
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), string(dnatSuffix)) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) RemoveNatRule(pair firewall.RouterPair) error {
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
|
|
||||||
if pair.Masquerade {
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
|
||||||
// counters will be off until the next successful removal or refresh cycle.
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *family) removeNatRule(pair firewall.RouterPair) error {
|
|
||||||
ruleID := pair.GenKey(firewall.PreroutingFormat)
|
|
||||||
|
|
||||||
rule, exists := r.rules[ruleID]
|
|
||||||
if !exists {
|
|
||||||
log.Debugf("prerouting rule %s not found", ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if rule.Handle == 0 {
|
|
||||||
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleID)
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
log.Warnf("decrement set counter for stale rule %s: %v", ruleID, err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -1,26 +1,21 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule wraps an installed filter rule (peer or route). Source set
|
// Rule to handle management of rules
|
||||||
// membership is encoded in the rule's expressions; DeleteFilterRule
|
|
||||||
// recovers the set name via findSets so the refcounter can drop the
|
|
||||||
// right reference. mangleRule is set only for peer rules.
|
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
nftRule *nftables.Rule
|
nftRule *nftables.Rule
|
||||||
mangleRule *nftables.Rule
|
mangleRule *nftables.Rule
|
||||||
// sources is the canonical source list this rule was created for.
|
nftSet *nftables.Set
|
||||||
sources []netip.Prefix
|
ruleID string
|
||||||
id manager.RuleID
|
ip net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) ID() manager.RuleID {
|
func (r *Rule) ID() string {
|
||||||
return r.id
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
//go:build integration && !android
|
|
||||||
|
|
||||||
package nftables
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
func pfx(ip net.IP) []netip.Prefix {
|
|
||||||
if ip == nil {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
if ip.IsUnspecified() {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
a, ok := netip.AddrFromSlice(ip)
|
|
||||||
if !ok {
|
|
||||||
panic(fmt.Sprintf("invalid IP length: %d", len(ip)))
|
|
||||||
}
|
|
||||||
a = a.Unmap()
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
|
||||||
}
|
|
||||||
37
client/firewall/uspfilter/allow_netbird.go
Normal file
37
client/firewall/uspfilter/allow_netbird.go
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
m.resetState()
|
||||||
|
|
||||||
|
if m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
}
|
||||||
|
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to untrust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowNetbird allows netbird interface traffic
|
||||||
|
func (m *Manager) AllowNetbird() error {
|
||||||
|
if m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.AllowNetbird()
|
||||||
|
}
|
||||||
|
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
|
||||||
|
log.Warnf("failed to trust interface in firewalld: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
type action string
|
type action string
|
||||||
@@ -19,20 +20,35 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WindowsInterfaceAllower opens the NetBird interface in the Windows firewall
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
// via netsh advfirewall rules. It implements InterfaceAllower for the userspace
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
// firewall on Windows.
|
m.mutex.Lock()
|
||||||
type WindowsInterfaceAllower struct {
|
defer m.mutex.Unlock()
|
||||||
iface Iface
|
|
||||||
|
m.resetState()
|
||||||
|
|
||||||
|
if !isWindowsFirewallReachable() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if isFirewallRuleActive(firewallRuleName) {
|
||||||
|
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
||||||
|
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWindowsInterfaceAllower builds the Windows netsh-based interface allower.
|
// AllowNetbird allows netbird interface traffic
|
||||||
func NewWindowsInterfaceAllower(iface Iface) *WindowsInterfaceAllower {
|
func (m *Manager) AllowNetbird() error {
|
||||||
return &WindowsInterfaceAllower{iface: iface}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply adds inbound-allow netsh rules for the interface's addresses.
|
|
||||||
func (a *WindowsInterfaceAllower) Apply() error {
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -44,13 +60,13 @@ func (a *WindowsInterfaceAllower) Apply() error {
|
|||||||
"enable=yes",
|
"enable=yes",
|
||||||
"action=allow",
|
"action=allow",
|
||||||
"profile=any",
|
"profile=any",
|
||||||
"localip="+a.iface.Address().IP.String(),
|
"localip="+m.wgIface.Address().IP.String(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v6 := a.iface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
if v6 := m.wgIface.Address().IPv6; v6.IsValid() && !isFirewallRuleActive(firewallRuleName+"-v6") {
|
||||||
if err := manageFirewallRule(firewallRuleName+"-v6",
|
if err := manageFirewallRule(firewallRuleName+"-v6",
|
||||||
addRule,
|
addRule,
|
||||||
"dir=in",
|
"dir=in",
|
||||||
@@ -66,27 +82,8 @@ func (a *WindowsInterfaceAllower) Apply() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close removes the netsh rules added by Apply.
|
|
||||||
func (a *WindowsInterfaceAllower) Close() error {
|
|
||||||
if !isWindowsFirewallReachable() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
if isFirewallRuleActive(firewallRuleName) {
|
|
||||||
if err := manageFirewallRule(firewallRuleName, deleteRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove windows firewall rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isFirewallRuleActive(firewallRuleName + "-v6") {
|
|
||||||
if err := manageFirewallRule(firewallRuleName+"-v6", deleteRule); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove windows v6 firewall rule: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
func manageFirewallRule(ruleName string, action action, extraArgs ...string) error {
|
||||||
|
|
||||||
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
args := []string{"advfirewall", "firewall", string(action), "rule", "name=" + ruleName}
|
||||||
if action == addRule {
|
if action == addRule {
|
||||||
args = append(args, extraArgs...)
|
args = append(args, extraArgs...)
|
||||||
17
client/firewall/uspfilter/common/iface.go
Normal file
17
client/firewall/uspfilter/common/iface.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
Name() string
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
Address() wgaddr.Address
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
@@ -19,18 +20,14 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -61,10 +58,7 @@ const (
|
|||||||
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
||||||
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
||||||
|
|
||||||
// EnvForceUserspaceRouter is a deprecated alias for
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
// NB_FORCE_USERSPACE_FIREWALL: the userspace firewall always routes in
|
|
||||||
// userspace, so forcing one forces the other. Kept for backward
|
|
||||||
// compatibility.
|
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||||
@@ -76,20 +70,14 @@ const (
|
|||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
// errNotSupported is returned by firewall operations that only make sense with
|
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
// a kernel firewall (kernel NAT/DNAT, eBPF) and are not implemented in
|
|
||||||
// userspace mode, where they should not be called.
|
|
||||||
var errNotSupported = errors.New("not supported with userspace firewall")
|
|
||||||
|
|
||||||
// peerRules is the canonical list-based storage for peer ACL rules.
|
// RuleSet is a set of rules grouped by a string key
|
||||||
// Drop and accept rules live in separate slices; drop-before-accept
|
type RuleSet map[string]PeerRule
|
||||||
// ordering comes from consulting the deny slice (and its index) before
|
|
||||||
// the accept one.
|
|
||||||
type peerRules []*PeerRule
|
|
||||||
|
|
||||||
type routeRules []*RouteRule
|
type RouteRules []*RouteRule
|
||||||
|
|
||||||
func (r routeRules) Sort() {
|
func (r RouteRules) Sort() {
|
||||||
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||||
// Deny rules come first
|
// Deny rules come first
|
||||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
@@ -98,74 +86,22 @@ func (r routeRules) Sort() {
|
|||||||
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return strings.Compare(string(a.id), string(b.id))
|
return strings.Compare(a.id, b.id)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// peerRuleSpec carries the parameters that define a peer filter rule,
|
|
||||||
// threaded together through the build path so the builders take a single
|
|
||||||
// argument instead of a long parameter list.
|
|
||||||
type peerRuleSpec struct {
|
|
||||||
mgmtID []byte
|
|
||||||
sources []netip.Prefix
|
|
||||||
ipLayer gopacket.LayerType
|
|
||||||
proto firewall.Protocol
|
|
||||||
sPort *firewall.Port
|
|
||||||
dPort *firewall.Port
|
|
||||||
action firewall.Action
|
|
||||||
}
|
|
||||||
|
|
||||||
// Iface is the network interface the userspace firewall attaches to: the
|
|
||||||
// methods of the WireGuard device it actually uses.
|
|
||||||
type Iface interface {
|
|
||||||
Name() string
|
|
||||||
Address() wgaddr.Address
|
|
||||||
SetFilter(device.PacketFilter) error
|
|
||||||
GetWGDevice() *wgdevice.Device
|
|
||||||
}
|
|
||||||
|
|
||||||
// InterfaceAllower opens the NetBird interface in the host firewall so it
|
|
||||||
// doesn't drop traffic the userspace firewall handles, without taking over
|
|
||||||
// packet filtering. Implementations (nftables, iptables, firewalld, the windows
|
|
||||||
// netsh rule) are selected per platform and injected into Create; Apply runs at
|
|
||||||
// creation and Close on teardown.
|
|
||||||
type InterfaceAllower interface {
|
|
||||||
Apply() error
|
|
||||||
Close() error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Config holds the dependencies and options for the userspace firewall.
|
|
||||||
type Config struct {
|
|
||||||
// IFace is the overlay interface the filter attaches to.
|
|
||||||
IFace Iface
|
|
||||||
// InterfaceAllower opens the NetBird interface in foreign kernel filter
|
|
||||||
// chains so the kernel doesn't drop traffic the userspace firewall handles.
|
|
||||||
// Nil in netstack mode, on non-Linux platforms without a backend, or when
|
|
||||||
// neither nftables nor iptables is available. firewalld trust is applied by
|
|
||||||
// the manager regardless, since firewalld owns its own chains and we cannot
|
|
||||||
// insert into them.
|
|
||||||
InterfaceAllower InterfaceAllower
|
|
||||||
// DisableServerRoutes indicates whether server routes are disabled.
|
|
||||||
DisableServerRoutes bool
|
|
||||||
FlowLogger nftypes.FlowLogger
|
|
||||||
MTU uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
decoders sync.Pool
|
outgoingRules map[netip.Addr]RuleSet
|
||||||
wgIface Iface
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
ifaceAllower InterfaceAllower
|
incomingRules map[netip.Addr]RuleSet
|
||||||
mutex sync.RWMutex
|
routeRules RouteRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
|
decoders sync.Pool
|
||||||
|
wgIface common.IFaceMapper
|
||||||
|
nativeFirewall firewall.Manager
|
||||||
|
|
||||||
incomingDenyRules peerRules
|
mutex sync.RWMutex
|
||||||
incomingAcceptRules peerRules
|
|
||||||
incomingDenyIndex peerRuleIndex
|
|
||||||
incomingAcceptIndex peerRuleIndex
|
|
||||||
peerRulesMap map[nbid.RuleID]*PeerRule
|
|
||||||
|
|
||||||
routeRules routeRules
|
|
||||||
routeRulesMap map[nbid.RuleID]*RouteRule
|
|
||||||
|
|
||||||
// indicates whether server routes are disabled
|
// indicates whether server routes are disabled
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
@@ -247,6 +183,24 @@ func (d *decoder) decodePacket(data []byte) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create userspace firewall manager constructor
|
||||||
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
|
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
|
if nativeFirewall == nil {
|
||||||
|
return nil, errors.New("native firewall is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseCreateEnv() (bool, bool, bool) {
|
func parseCreateEnv() (bool, bool, bool) {
|
||||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||||
var err error
|
var err error
|
||||||
@@ -277,7 +231,7 @@ func parseCreateEnv() (bool, bool, bool) {
|
|||||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||||
}
|
}
|
||||||
|
|
||||||
func Create(cfg Config) (_ *Manager, err error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -300,131 +254,62 @@ func Create(cfg Config) (_ *Manager, err error) {
|
|||||||
return d
|
return d
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
wgIface: cfg.IFace,
|
nativeFirewall: nativeFirewall,
|
||||||
ifaceAllower: cfg.InterfaceAllower,
|
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||||
|
incomingDenyRules: make(map[netip.Addr]RuleSet),
|
||||||
|
incomingRules: make(map[netip.Addr]RuleSet),
|
||||||
|
wgIface: iface,
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
disableServerRoutes: cfg.DisableServerRoutes,
|
disableServerRoutes: disableServerRoutes,
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
flowLogger: cfg.FlowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
peerRulesMap: make(map[nbid.RuleID]*PeerRule),
|
|
||||||
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
mtu: cfg.MTU,
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
// Release the allower (and its monitor) if setup fails after it was wired in.
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
m.closeAllowerOnError()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if !disableMSSClamping {
|
if !disableMSSClamping {
|
||||||
m.enableMSSClamping(cfg.MTU)
|
m.mssClampEnabled = true
|
||||||
|
if mtu > ipv4TCPHeaderMinSize {
|
||||||
|
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
||||||
|
}
|
||||||
|
if mtu > ipv6TCPHeaderMinSize {
|
||||||
|
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err := m.localipmanager.UpdateLocalIPs(cfg.IFace); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
m.setupConntrack(disableConntrack)
|
if disableConntrack {
|
||||||
|
log.Info("conntrack is disabled")
|
||||||
|
} else {
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||||
|
}
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
if err := m.initForwarder(); err != nil {
|
if err := m.initForwarder(); err != nil {
|
||||||
log.Errorf("failed to initialize forwarder: %v", err)
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := cfg.IFace.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, fmt.Errorf("set filter: %w", err)
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.openHostFirewall(cfg.IFace.Name())
|
|
||||||
|
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeAllowerOnError releases the allower (and its monitor) when Create fails
|
|
||||||
// after the allower was wired in.
|
|
||||||
func (m *Manager) closeAllowerOnError() {
|
|
||||||
if m.ifaceAllower == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := m.ifaceAllower.Close(); err != nil {
|
|
||||||
log.Warnf("close interface allower after failed firewall setup: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// enableMSSClamping enables MSS clamping and computes the per-family clamp values.
|
|
||||||
func (m *Manager) enableMSSClamping(mtu uint16) {
|
|
||||||
m.mssClampEnabled = true
|
|
||||||
if mtu > ipv4TCPHeaderMinSize {
|
|
||||||
m.mssClampValueIPv4 = mtu - ipv4TCPHeaderMinSize
|
|
||||||
}
|
|
||||||
if mtu > ipv6TCPHeaderMinSize {
|
|
||||||
m.mssClampValueIPv6 = mtu - ipv6TCPHeaderMinSize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupConntrack initializes the stateful trackers unless conntrack is disabled.
|
|
||||||
func (m *Manager) setupConntrack(disabled bool) {
|
|
||||||
if disabled {
|
|
||||||
log.Info("conntrack is disabled")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
|
||||||
|
|
||||||
// openHostFirewall opens the NetBird interface in the kernel firewall so it
|
|
||||||
// doesn't drop traffic the userspace firewall handles. Best-effort: failures
|
|
||||||
// here shouldn't prevent the firewall from coming up.
|
|
||||||
func (m *Manager) openHostFirewall(ifaceName string) {
|
|
||||||
if m.ifaceAllower != nil {
|
|
||||||
if err := m.ifaceAllower.Apply(); err != nil {
|
|
||||||
log.Errorf("failed to allow netbird interface traffic: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// firewalld owns its own chains we can't insert into, so trust the interface
|
|
||||||
// there in addition to the allower. Netstack has no kernel interface.
|
|
||||||
if !m.netstack {
|
|
||||||
if err := firewalld.TrustInterface(ifaceName); err != nil {
|
|
||||||
log.Warnf("failed to trust interface in firewalld: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close cleans up the firewall manager: removes rules, closes trackers, and
|
|
||||||
// closes the interface allower.
|
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.resetState()
|
|
||||||
|
|
||||||
var merr *multierror.Error
|
|
||||||
if m.ifaceAllower != nil {
|
|
||||||
if err := m.ifaceAllower.Close(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("close interface allower: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !m.netstack {
|
|
||||||
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("untrust interface in firewalld: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
|
// blockInvalidRouted installs drop rules for traffic to the wg overlay that
|
||||||
// arrives via the routing path. v4 and v6 are independent: a v6 install
|
// arrives via the routing path. v4 and v6 are independent: a v6 install
|
||||||
// failure leaves v4 protection in place (and vice versa) so the returned
|
// failure leaves v4 protection in place (and vice versa) so the returned
|
||||||
// slice always contains whatever was successfully installed, even on error.
|
// slice always contains whatever was successfully installed, even on error.
|
||||||
// Callers must persist the slice so DisableRouting can clean partial state.
|
// Callers must persist the slice so DisableRouting can clean partial state.
|
||||||
func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) ([]firewall.Rule, error) {
|
||||||
wgPrefix := iface.Address().Network
|
wgPrefix := iface.Address().Network
|
||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
@@ -435,7 +320,7 @@ func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
v4Rule, err := m.addRouteRule(
|
v4Rule, err := m.addRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
sources,
|
sources,
|
||||||
firewall.Network{Prefix: wgPrefix},
|
firewall.Network{Prefix: wgPrefix},
|
||||||
@@ -451,7 +336,7 @@ func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
|
|||||||
|
|
||||||
if v6Net.IsValid() {
|
if v6Net.IsValid() {
|
||||||
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
log.Debugf("blocking invalid routed traffic for %s", v6Net)
|
||||||
v6Rule, err := m.addRouteRule(
|
v6Rule, err := m.addRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
sources,
|
sources,
|
||||||
firewall.Network{Prefix: v6Net},
|
firewall.Network{Prefix: v6Net},
|
||||||
@@ -472,14 +357,20 @@ func (m *Manager) blockInvalidRouted(iface Iface) ([]firewall.Rule, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) determineRouting() error {
|
func (m *Manager) determineRouting() error {
|
||||||
var disableUspRouting bool
|
var disableUspRouting, forceUserspaceRouter bool
|
||||||
|
var err error
|
||||||
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
if val := os.Getenv(EnvDisableUserspaceRouting); val != "" {
|
||||||
var err error
|
|
||||||
disableUspRouting, err = strconv.ParseBool(val)
|
disableUspRouting, err = strconv.ParseBool(val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
log.Warnf("failed to parse %s: %v", EnvDisableUserspaceRouting, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if val := os.Getenv(EnvForceUserspaceRouter); val != "" {
|
||||||
|
forceUserspaceRouter, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvForceUserspaceRouter, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
@@ -494,11 +385,26 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
|
case forceUserspaceRouter:
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
|
case !m.netstack && m.nativeFirewall != nil:
|
||||||
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
m.routingEnabled.Store(true)
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter.Store(false)
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing enabled")
|
log.Info("userspace routing enabled by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||||
@@ -564,118 +470,82 @@ func (m *Manager) IsStateful() bool {
|
|||||||
return m.stateful
|
return m.stateful
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
|
}
|
||||||
|
|
||||||
// userspace routed packets are always SNATed to the inbound direction
|
// userspace routed packets are always SNATed to the inbound direction
|
||||||
// TODO: implement outbound SNAT
|
// TODO: implement outbound SNAT
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addPeerRule installs an input-chain rule that matches packets
|
// AddPeerFiltering rule to the firewall
|
||||||
// by source only. Called from AddFilterRule when the caller doesn't
|
//
|
||||||
// specify a destination. Sources are expected to share one address
|
// If comment argument is empty firewall manager should set
|
||||||
// family; the family selects the ipLayer so the ICMP variant matches
|
// rule ID as comment for the rule
|
||||||
// what the decoder produces.
|
func (m *Manager) AddPeerFiltering(
|
||||||
func (m *Manager) addPeerRule(
|
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
_ string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
|
// TODO: fix in upper layers
|
||||||
|
i, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
i = i.Unmap()
|
||||||
|
r := PeerRule{
|
||||||
|
id: uuid.New().String(),
|
||||||
|
mgmtId: id,
|
||||||
|
ip: i,
|
||||||
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
|
matchByIP: true,
|
||||||
|
drop: action == firewall.ActionDrop,
|
||||||
|
}
|
||||||
|
if i.Is4() {
|
||||||
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
|
}
|
||||||
|
|
||||||
|
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||||
|
r.matchByIP = false
|
||||||
|
}
|
||||||
|
|
||||||
|
r.sPort = sPort
|
||||||
|
r.dPort = dPort
|
||||||
|
|
||||||
|
r.protoLayer = protoToLayer(proto, r.ipLayer)
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
var targetMap map[netip.Addr]RuleSet
|
||||||
|
if r.drop {
|
||||||
// Sources are a single family; normalize v4-mapped prefixes to plain
|
targetMap = m.incomingDenyRules
|
||||||
// v4 and pick the matching IP layer. A /0 source matches any address
|
|
||||||
// of its own family only, mirroring the kernel backends.
|
|
||||||
normalized := make([]netip.Prefix, len(sources))
|
|
||||||
ipLayer := layers.LayerTypeIPv4
|
|
||||||
for i, p := range sources {
|
|
||||||
normalized[i] = firewall.UnmapPrefix(p)
|
|
||||||
if normalized[i].Addr().Is6() {
|
|
||||||
ipLayer = layers.LayerTypeIPv6
|
|
||||||
}
|
|
||||||
}
|
|
||||||
spec := peerRuleSpec{
|
|
||||||
mgmtID: id,
|
|
||||||
sources: normalized,
|
|
||||||
ipLayer: ipLayer,
|
|
||||||
proto: proto,
|
|
||||||
sPort: sPort,
|
|
||||||
dPort: dPort,
|
|
||||||
action: action,
|
|
||||||
}
|
|
||||||
return m.addOnePeerRule(spec), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addOnePeerRule builds and registers a single-family peer rule, or
|
|
||||||
// returns the existing rule when one with the same content key is
|
|
||||||
// already installed. The caller must hold m.mutex. The content key is
|
|
||||||
// the shared GenerateRuleID with an empty destination, so peer rules
|
|
||||||
// dedup the same way route rules and the kernel backends do; it is
|
|
||||||
// order-independent, so callers passing the same sources in any order
|
|
||||||
// dedup to one rule.
|
|
||||||
//
|
|
||||||
// There is no refcount: a content key is installed once and deleted on
|
|
||||||
// the first DeleteFilterRule for that key. The caller must therefore
|
|
||||||
// key its own tracking by the returned rule id so add and delete stay
|
|
||||||
// balanced per content key; the acl manager does this via
|
|
||||||
// peerRulesPairs.
|
|
||||||
func (m *Manager) addOnePeerRule(spec peerRuleSpec) *PeerRule {
|
|
||||||
ruleID := nbid.GenerateRuleID(spec.sources, firewall.Network{}, spec.proto, spec.sPort, spec.dPort, spec.action)
|
|
||||||
if existing, ok := m.peerRulesMap[ruleID]; ok {
|
|
||||||
return existing
|
|
||||||
}
|
|
||||||
|
|
||||||
rule := m.buildPeerRule(ruleID, spec)
|
|
||||||
m.registerPeerRule(rule)
|
|
||||||
return rule
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) buildPeerRule(ruleID nbid.RuleID, spec peerRuleSpec) *PeerRule {
|
|
||||||
r := &PeerRule{
|
|
||||||
id: ruleID,
|
|
||||||
mgmtId: spec.mgmtID,
|
|
||||||
sources: spec.sources,
|
|
||||||
action: spec.action,
|
|
||||||
srcPort: spec.sPort,
|
|
||||||
dstPort: spec.dPort,
|
|
||||||
}
|
|
||||||
r.sourceAddrs = make(map[netip.Addr]struct{}, len(spec.sources))
|
|
||||||
for _, p := range spec.sources {
|
|
||||||
if p.Bits() == p.Addr().BitLen() {
|
|
||||||
r.sourceAddrs[p.Addr()] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r.protoLayer = protoToLayer(spec.proto, spec.ipLayer)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// registerPeerRule records a freshly built peer rule in the matching
|
|
||||||
// slice, index, and dedup map. The caller must hold m.mutex.
|
|
||||||
func (m *Manager) registerPeerRule(r *PeerRule) {
|
|
||||||
if r.action == firewall.ActionDrop {
|
|
||||||
m.incomingDenyRules = append(m.incomingDenyRules, r)
|
|
||||||
m.incomingDenyIndex.add(r)
|
|
||||||
} else {
|
} else {
|
||||||
m.incomingAcceptRules = append(m.incomingAcceptRules, r)
|
targetMap = m.incomingRules
|
||||||
m.incomingAcceptIndex.add(r)
|
|
||||||
}
|
}
|
||||||
m.peerRulesMap[r.id] = r
|
|
||||||
|
if _, ok := targetMap[r.ip]; !ok {
|
||||||
|
targetMap[r.ip] = make(RuleSet)
|
||||||
|
}
|
||||||
|
targetMap[r.ip][r.id] = r
|
||||||
|
m.mutex.Unlock()
|
||||||
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddFilterRule is the unified entry point for both peer (input chain)
|
func (m *Manager) AddRouteFiltering(
|
||||||
// and route (forward chain) filtering rules. The destination
|
|
||||||
// distinguishes the two semantics: a zero Network installs an
|
|
||||||
// input-side rule that matches by source only; a set Network installs
|
|
||||||
// a forward-side rule that also matches the destination.
|
|
||||||
func (m *Manager) AddFilterRule(
|
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination firewall.Network,
|
destination firewall.Network,
|
||||||
@@ -683,34 +553,13 @@ func (m *Manager) AddFilterRule(
|
|||||||
sPort, dPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if len(sources) == 0 {
|
|
||||||
return nil, firewall.ErrNoSources
|
|
||||||
}
|
|
||||||
|
|
||||||
if destination.IsZero() {
|
|
||||||
return m.addPeerRule(id, sources, proto, sPort, dPort, action)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
return m.addRouteRule(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteFilterRule deletes a filtering rule. The rule's underlying type
|
|
||||||
// is used to route to the correct internal path.
|
|
||||||
func (m *Manager) DeleteFilterRule(rule firewall.Rule) error {
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if r, ok := rule.(*PeerRule); ok {
|
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
return m.deletePeerRuleLocked(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Anything else is a route rule (matched on the forward path).
|
|
||||||
return m.deleteRouteRule(rule)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) addRouteRule(
|
func (m *Manager) addRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination firewall.Network,
|
destination firewall.Network,
|
||||||
@@ -718,14 +567,19 @@ func (m *Manager) addRouteRule(
|
|||||||
sPort, dPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
ruleID := nbid.GenerateRuleID(sources, destination, proto, sPort, dPort, action)
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
if existingRule, ok := m.routeRulesMap[ruleID]; ok {
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||||
return existingRule, nil
|
return existingRule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
id: ruleID,
|
// TODO: consolidate these IDs
|
||||||
|
id: string(ruleKey),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -740,58 +594,78 @@ func (m *Manager) addRouteRule(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
m.routeRulesMap[ruleID] = &rule
|
m.routeRulesMap[ruleKey] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.deleteRouteRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||||
ruleID := rule.ID()
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
trimmed, _, ok := removeRuleByID(m.routeRules, ruleID)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
|
||||||
}
|
}
|
||||||
m.routeRules = trimmed
|
|
||||||
delete(m.routeRulesMap, ruleID)
|
ruleKey := nbid.RuleID(rule.ID())
|
||||||
|
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
|
return r.id == string(ruleKey)
|
||||||
|
})
|
||||||
|
if idx < 0 {
|
||||||
|
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
delete(m.routeRulesMap, ruleKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// deletePeerRuleLocked removes a peer rule from the matching slice,
|
// DeletePeerRule from the firewall by rule definition
|
||||||
// index, and dedup map. The caller must hold m.mutex.
|
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
||||||
func (m *Manager) deletePeerRuleLocked(r *PeerRule) error {
|
m.mutex.Lock()
|
||||||
target, index := &m.incomingAcceptRules, &m.incomingAcceptIndex
|
defer m.mutex.Unlock()
|
||||||
if r.action == firewall.ActionDrop {
|
|
||||||
target, index = &m.incomingDenyRules, &m.incomingDenyIndex
|
r, ok := rule.(*PeerRule)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
trimmed, stored, ok := removeRuleByID(*target, r.id)
|
var sourceMap map[netip.Addr]RuleSet
|
||||||
if !ok {
|
if r.drop {
|
||||||
|
sourceMap = m.incomingDenyRules
|
||||||
|
} else {
|
||||||
|
sourceMap = m.incomingRules
|
||||||
|
}
|
||||||
|
|
||||||
|
if ruleset, ok := sourceMap[r.ip]; ok {
|
||||||
|
if _, exists := ruleset[r.id]; !exists {
|
||||||
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
|
}
|
||||||
|
delete(ruleset, r.id)
|
||||||
|
if len(ruleset) == 0 {
|
||||||
|
delete(sourceMap, r.ip)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
*target = trimmed
|
|
||||||
index.remove(stored)
|
|
||||||
delete(m.peerRulesMap, r.id)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeRuleByID removes the first rule whose id matches ruleID from
|
// SetLegacyManagement doesn't need to be implemented for this manager
|
||||||
// rules, preserving order. It returns the trimmed slice, the removed
|
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
||||||
// rule, and whether a match was found.
|
if m.nativeFirewall == nil {
|
||||||
func removeRuleByID[S ~[]T, T firewall.Rule](rules S, ruleID firewall.RuleID) (S, T, bool) {
|
return nil
|
||||||
idx := slices.IndexFunc(rules, func(r T) bool { return r.ID() == ruleID })
|
|
||||||
var removed T
|
|
||||||
if idx < 0 {
|
|
||||||
return rules, removed, false
|
|
||||||
}
|
}
|
||||||
removed = rules[idx]
|
return m.nativeFirewall.SetLegacyManagement(isLegacy)
|
||||||
return slices.Delete(rules, idx, idx+1), removed, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetLegacyManagement is a no-op for the userspace firewall: it only matters
|
|
||||||
// when an old management server can't send route firewall rules, which the
|
|
||||||
// userspace router doesn't rely on.
|
|
||||||
func (m *Manager) SetLegacyManagement(bool) error {
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
@@ -800,14 +674,11 @@ func (m *Manager) Flush() error { return nil }
|
|||||||
// resetState clears all firewall rules and closes connection trackers.
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
// Must be called with m.mutex held.
|
// Must be called with m.mutex held.
|
||||||
func (m *Manager) resetState() {
|
func (m *Manager) resetState() {
|
||||||
m.incomingDenyRules = m.incomingDenyRules[:0]
|
clear(m.outgoingRules)
|
||||||
m.incomingAcceptRules = m.incomingAcceptRules[:0]
|
clear(m.incomingDenyRules)
|
||||||
m.incomingDenyIndex.reset()
|
clear(m.incomingRules)
|
||||||
m.incomingAcceptIndex.reset()
|
|
||||||
clear(m.peerRulesMap)
|
|
||||||
clear(m.routeRulesMap)
|
clear(m.routeRulesMap)
|
||||||
m.routeRules = m.routeRules[:0]
|
m.routeRules = m.routeRules[:0]
|
||||||
m.blockRules = nil
|
|
||||||
m.udpHookOut.Store(nil)
|
m.udpHookOut.Store(nil)
|
||||||
m.tcpHookOut.Store(nil)
|
m.tcpHookOut.Store(nil)
|
||||||
|
|
||||||
@@ -837,15 +708,21 @@ func (m *Manager) resetState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack is not supported by the userspace firewall: eBPF isn't
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
// used in userspace mode, so this should never be called.
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(uint16, uint16) error {
|
if m.nativeFirewall == nil {
|
||||||
return errNotSupported
|
return nil
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSet updates the rule destinations associated with the given set
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
// by merging the existing prefixes with the new ones, then deduplicating.
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -943,11 +820,11 @@ func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
|||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||||
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
return src.Unmap(), dst.Unmap()
|
return src, dst
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||||
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||||
return src.Unmap(), dst.Unmap()
|
return src, dst
|
||||||
default:
|
default:
|
||||||
return netip.Addr{}, netip.Addr{}
|
return netip.Addr{}, netip.Addr{}
|
||||||
}
|
}
|
||||||
@@ -1527,12 +1404,20 @@ func (m *Manager) peerACLsBlock(srcIP netip.Addr, d *decoder, packetData []byte)
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if mgmtId, filter, ok := m.incomingDenyIndex.match(srcIP, d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingDenyRules[srcIP], d); ok {
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
if mgmtId, filter, ok := m.incomingAcceptIndex.match(srcIP, d); ok {
|
|
||||||
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[srcIP], d); ok {
|
||||||
return mgmtId, filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv4Unspecified()], d); ok {
|
||||||
|
return mgmtId, filter
|
||||||
|
}
|
||||||
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, m.incomingRules[netip.IPv6Unspecified()], d); ok {
|
||||||
|
return mgmtId, filter
|
||||||
|
}
|
||||||
|
|
||||||
return nil, true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1553,6 +1438,39 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||||
|
payloadLayer := d.decoded[1]
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.protoLayer == layerTypeAll {
|
||||||
|
return rule.mgmtId, rule.drop, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch payloadLayer {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||||
|
return rule.mgmtId, rule.drop, true
|
||||||
|
}
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
|
return rule.mgmtId, rule.drop, true
|
||||||
|
}
|
||||||
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
|
return rule.mgmtId, rule.drop, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false, false
|
||||||
|
}
|
||||||
|
|
||||||
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||||
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
@@ -1629,13 +1547,10 @@ func (m *Manager) EnableRouting() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rules, err := m.blockInvalidRouted(m.wgIface)
|
rules, err := m.blockInvalidRouted(m.wgIface)
|
||||||
|
// Persist whatever was installed even on partial failure, so DisableRouting
|
||||||
|
// can clean it up later.
|
||||||
m.blockRules = rules
|
m.blockRules = rules
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Roll back so forwarding can't stay active without the full set of
|
|
||||||
// block rules.
|
|
||||||
if derr := m.disableRouting(); derr != nil {
|
|
||||||
log.Warnf("roll back routing after block rule failure: %v", derr)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("block invalid routed: %w", err)
|
return fmt.Errorf("block invalid routed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1646,10 +1561,6 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.disableRouting()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) disableRouting() error {
|
|
||||||
fwder := m.forwarder.Load()
|
fwder := m.forwarder.Load()
|
||||||
if fwder == nil {
|
if fwder == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: false,
|
stateful: false,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@@ -114,13 +114,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Add explicit rules matching return traffic pattern
|
// Add explicit rules matching return traffic pattern
|
||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddFilterRule(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
pfx(ip), fw.Network{},
|
ip,
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []uint16{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.ActionAccept)
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -131,13 +133,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: true,
|
stateful: true,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddFilterRule(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
pfx(net.ParseIP("0.0.0.0")), fw.Network{},
|
net.ParseIP("0.0.0.0"),
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
fw.ActionDrop)
|
fw.ActionDrop,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@@ -164,12 +168,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -207,12 +208,9 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
|
|
||||||
for _, count := range connCounts {
|
for _, count := range connCounts {
|
||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -253,12 +251,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -414,12 +409,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -544,12 +536,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -557,7 +546,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -630,12 +619,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -643,7 +629,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -744,19 +730,16 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -827,18 +810,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{
|
manager, _ := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddFilterRule(nil, pfx(net.ParseIP("0.0.0.0")), fw.Network{}, fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept)
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -951,7 +931,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
dst := fw.Network{Prefix: r.dest}
|
dst := fw.Network{Prefix: r.dest}
|
||||||
_, err := manager.AddFilterRule(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1034,11 +1014,9 @@ func BenchmarkMSSClamping(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -1101,11 +1079,9 @@ func BenchmarkMSSClampingOverhead(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -1158,11 +1134,9 @@ func BenchmarkMSSClampingMemory(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -496,32 +496,40 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.ruleAction == fw.ActionDrop {
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
// add general accept rule for the same IP to test drop rule precedence
|
// add general accept rule for the same IP to test drop rule precedence
|
||||||
rules, err := manager.AddFilterRule(
|
rules, err := manager.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
net.ParseIP(tc.ruleIP),
|
||||||
fw.ProtocolALL,
|
fw.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
fw.ActionAccept)
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := manager.AddFilterRule(
|
rules, err := manager.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
pfx(net.ParseIP(tc.ruleIP)), fw.Network{},
|
net.ParseIP(tc.ruleIP),
|
||||||
tc.ruleProto,
|
tc.ruleProto,
|
||||||
tc.ruleSrcPort,
|
tc.ruleSrcPort,
|
||||||
tc.ruleDstPort,
|
tc.ruleDstPort,
|
||||||
tc.ruleAction)
|
tc.ruleAction,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
@@ -549,7 +557,7 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
@@ -644,24 +652,14 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
|||||||
shouldBeBlocked: false,
|
shouldBeBlocked: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6: v4 wildcard ICMP rule does not match ICMPv6",
|
name: "IPv6: v4 wildcard ICMP rule matches ICMPv6 via protoLayerMatches",
|
||||||
srcIP: "fd00::1",
|
srcIP: "fd00::1",
|
||||||
dstIP: "fd00::100",
|
dstIP: "fd00::100",
|
||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
ruleIP: "0.0.0.0",
|
ruleIP: "0.0.0.0",
|
||||||
ruleProto: fw.ProtocolICMP,
|
ruleProto: fw.ProtocolICMP,
|
||||||
ruleAction: fw.ActionAccept,
|
ruleAction: fw.ActionAccept,
|
||||||
shouldBeBlocked: true,
|
shouldBeBlocked: false,
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "IPv4: v6 wildcard ICMP rule does not match ICMPv4",
|
|
||||||
srcIP: "100.10.0.1",
|
|
||||||
dstIP: "100.10.0.100",
|
|
||||||
proto: fw.ProtocolICMP,
|
|
||||||
ruleIP: "::",
|
|
||||||
ruleProto: fw.ProtocolICMP,
|
|
||||||
ruleAction: fw.ActionAccept,
|
|
||||||
shouldBeBlocked: true,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -674,18 +672,22 @@ func TestPeerACLFilteringIPv6(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.ruleAction == fw.ActionDrop {
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rules, err := manager.AddFilterRule(nil, pfx(net.ParseIP(tc.ruleIP)), fw.Network{}, tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction)
|
rules, err := manager.AddPeerFiltering(nil, net.ParseIP(tc.ruleIP), tc.ruleProto, nil, tc.ruleDstPort, tc.ruleAction, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rules))
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
@@ -798,7 +800,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
@@ -1403,7 +1405,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
if tc.rule.action == fw.ActionDrop {
|
if tc.rule.action == fw.ActionDrop {
|
||||||
// add general accept rule to test drop rule
|
// add general accept rule to test drop rule
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
@@ -1413,13 +1415,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
fw.ActionAccept,
|
fw.ActionAccept,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rule)
|
require.NotNil(t, rule)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
tc.rule.dest,
|
tc.rule.dest,
|
||||||
@@ -1429,10 +1431,10 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
tc.rule.action,
|
tc.rule.action,
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rule)
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
@@ -1600,9 +1602,9 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
var addedRules []fw.Rule
|
var rules []fw.Rule
|
||||||
for _, r := range tc.rules {
|
for _, r := range tc.rules {
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
r.sources,
|
r.sources,
|
||||||
r.dest,
|
r.dest,
|
||||||
@@ -1613,12 +1615,12 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, rule)
|
require.NotNil(t, rule)
|
||||||
addedRules = append(addedRules, rule)
|
rules = append(rules, rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
for _, rule := range addedRules {
|
for _, rule := range rules {
|
||||||
require.NoError(t, manager.DeleteFilterRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -1644,7 +1646,7 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -1653,7 +1655,7 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
// Add rule that uses the set (initially empty)
|
// Add rule that uses the set (initially empty)
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Set: set},
|
fw.Network{Set: set},
|
||||||
@@ -1687,7 +1689,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
|||||||
manager := setupRoutedManager(t, "10.10.0.100/16")
|
manager := setupRoutedManager(t, "10.10.0.100/16")
|
||||||
|
|
||||||
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
v6Dst := netip.MustParsePrefix("fd00:dead:beef::/48")
|
||||||
_, err := manager.AddFilterRule(
|
_, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||||
fw.Network{Prefix: v6Dst},
|
fw.Network{Prefix: v6Dst},
|
||||||
@@ -1698,7 +1700,7 @@ func TestRouteACLFilteringIPv6(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = manager.AddFilterRule(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
[]netip.Prefix{netip.MustParsePrefix("fd00::/16")},
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
fw.Network{Prefix: netip.MustParsePrefix("fd00:dead:beef:1::/64")},
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
|||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
// Add rule first time
|
// Add rule first time
|
||||||
rule1, err := manager.AddFilterRule(
|
rule1, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -42,7 +42,7 @@ func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
|||||||
require.NotNil(t, rule1)
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
// Add the same rule again
|
// Add the same rule again
|
||||||
rule2, err := manager.AddFilterRule(
|
rule2, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -74,7 +74,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
|||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
// Add first rule
|
// Add first rule
|
||||||
rule1, err := manager.AddFilterRule(
|
rule1, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
@@ -86,7 +86,7 @@ func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add different rule (different destination)
|
// Add different rule (different destination)
|
||||||
rule2, err := manager.AddFilterRule(
|
rule2, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-2"),
|
[]byte("policy-2"),
|
||||||
sources,
|
sources,
|
||||||
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
@@ -115,7 +115,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
rule1, err := manager.AddFilterRule(
|
rule1, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -132,7 +132,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
require.True(t, pass, "Traffic should pass with rule in place")
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
// Re-add same rule (simulates network map update)
|
// Re-add same rule (simulates network map update)
|
||||||
rule2, err := manager.AddFilterRule(
|
rule2, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -147,7 +147,7 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
// would remove the only matching rule and cause a traffic gap.
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
if rule1.ID() != rule2.ID() {
|
if rule1.ID() != rule2.ID() {
|
||||||
err = manager.DeleteFilterRule(rule1)
|
err = manager.DeleteRouteRule(rule1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,59 +156,6 @@ func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
|||||||
"Traffic should still pass after rule update - no gap should occur")
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestBlockInvalidRoutedDualStack verifies that when the interface has an
|
|
||||||
// IPv6 overlay address, blockInvalidRouted installs a block rule for both
|
|
||||||
// the v4 and v6 WG prefixes and that routed traffic to the v6 prefix is
|
|
||||||
// denied. The v4-only soft-skip path is covered by
|
|
||||||
// TestBlockInvalidRoutedIdempotent, whose mock leaves IPv6Net invalid.
|
|
||||||
func TestBlockInvalidRoutedDualStack(t *testing.T) {
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
dev := mocks.NewMockDevice(ctrl)
|
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
|
||||||
|
|
||||||
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
|
||||||
wgNet6 := netip.MustParsePrefix("fd00:1234::1/64")
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: wgNet.Addr(),
|
|
||||||
Network: wgNet,
|
|
||||||
IPv6: wgNet6.Addr(),
|
|
||||||
IPv6Net: wgNet6,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
GetDeviceFunc: func() *device.FilteredDevice {
|
|
||||||
return &device.FilteredDevice{Device: dev}
|
|
||||||
},
|
|
||||||
GetWGDeviceFunc: func() *wgdevice.Device {
|
|
||||||
return &wgdevice.Device{}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
require.NoError(t, manager.Close(nil))
|
|
||||||
})
|
|
||||||
|
|
||||||
rules, err := manager.blockInvalidRouted(ifaceMock)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, rules, 2, "dual-stack interface must produce a v4 and a v6 block rule")
|
|
||||||
|
|
||||||
manager.mutex.RLock()
|
|
||||||
ruleCount := len(manager.routeRules)
|
|
||||||
manager.mutex.RUnlock()
|
|
||||||
assert.Equal(t, 2, ruleCount, "should have one block rule per family")
|
|
||||||
|
|
||||||
// v6 routed traffic to the WG prefix must be denied.
|
|
||||||
srcIP := netip.MustParseAddr("2001:db8::1")
|
|
||||||
dstIP := netip.MustParseAddr("fd00:1234::50")
|
|
||||||
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
|
||||||
assert.False(t, pass, "block rule should deny routed traffic to the v6 WG prefix")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
// returns the same rule without duplicating.
|
// returns the same rule without duplicating.
|
||||||
@@ -235,7 +182,7 @@ func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -298,7 +245,7 @@ func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -327,7 +274,7 @@ func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
|||||||
|
|
||||||
// Simulate 5 network map updates with the same route rule
|
// Simulate 5 network map updates with the same route rule
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -357,7 +304,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|||||||
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
// Add same rule twice
|
// Add same rule twice
|
||||||
rule1, err := manager.AddFilterRule(
|
rule1, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -368,7 +315,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rule2, err := manager.AddFilterRule(
|
rule2, err := manager.AddRouteFiltering(
|
||||||
[]byte("policy-1"),
|
[]byte("policy-1"),
|
||||||
sources,
|
sources,
|
||||||
destination,
|
destination,
|
||||||
@@ -382,7 +329,7 @@ func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
|||||||
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
// Delete using first reference
|
// Delete using first reference
|
||||||
err = manager.DeleteFilterRule(rule1)
|
err = manager.DeleteRouteRule(rule1)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify traffic no longer passes
|
// Verify traffic no longer passes
|
||||||
@@ -417,7 +364,7 @@ func setupTestManager(t *testing.T) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.EnableRouting())
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
|||||||
@@ -78,19 +78,18 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
|
||||||
|
|
||||||
if m == nil {
|
if m == nil {
|
||||||
t.Error("Manager is nil")
|
t.Error("Manager is nil")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManagerAddFilterRule(t *testing.T) {
|
func TestManagerAddPeerFiltering(t *testing.T) {
|
||||||
isSetFilterCalled := false
|
isSetFilterCalled := false
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error {
|
SetFilterFunc: func(device.PacketFilter) error {
|
||||||
@@ -99,19 +98,18 @@ func TestManagerAddFilterRule(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
rule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -133,47 +131,74 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
|
||||||
|
|
||||||
ip := netip.MustParseAddr("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
rule2, err := m.AddFilterRule(nil, pfx(ip.AsSlice()), fw.Network{}, proto, nil, port, action)
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
peerRule, ok := rule2.(*PeerRule)
|
// Check rules exist in appropriate maps
|
||||||
require.True(t, ok, "rule should be a peer rule")
|
for _, r := range rule2 {
|
||||||
|
peerRule, ok := r.(*PeerRule)
|
||||||
inMap := func() bool {
|
if !ok {
|
||||||
if peerRule.action == fw.ActionDrop {
|
t.Errorf("rule should be a PeerRule")
|
||||||
return findRuleByID(m.incomingDenyRules, ip, rule2.ID())
|
continue
|
||||||
|
}
|
||||||
|
// Check if rule exists in deny or allow maps based on action
|
||||||
|
var found bool
|
||||||
|
if peerRule.drop {
|
||||||
|
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||||
|
} else {
|
||||||
|
_, found = m.incomingRules[ip][r.ID()]
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("rule2 is not in the expected rules map")
|
||||||
}
|
}
|
||||||
return findRuleByID(m.incomingAcceptRules, ip, rule2.ID())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
require.True(t, inMap(), "rule2 should be in the expected rules list")
|
for _, r := range rule2 {
|
||||||
|
err = m.DeletePeerRule(r)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to delete rule: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
require.NoError(t, m.DeleteFilterRule(rule2), "failed to delete rule")
|
// Check rules are removed from appropriate maps
|
||||||
|
for _, r := range rule2 {
|
||||||
require.False(t, inMap(), "rule2 should be removed from the rules list")
|
peerRule, ok := r.(*PeerRule)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("rule should be a PeerRule")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Check if rule is removed from deny or allow maps based on action
|
||||||
|
var found bool
|
||||||
|
if peerRule.drop {
|
||||||
|
_, found = m.incomingDenyRules[ip][r.ID()]
|
||||||
|
} else {
|
||||||
|
_, found = m.incomingRules[ip][r.ID()]
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
t.Errorf("rule2 should be removed from the rules map")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetUDPPacketHook(t *testing.T) {
|
func TestSetUDPPacketHook(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
@@ -195,11 +220,9 @@ func TestSetUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSetTCPPacketHook(t *testing.T) {
|
func TestSetTCPPacketHook(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
|
||||||
|
|
||||||
@@ -227,7 +250,7 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, m.Close(nil))
|
require.NoError(t, m.Close(nil))
|
||||||
@@ -237,34 +260,36 @@ func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
|||||||
addr := netip.MustParseAddr("192.168.1.1")
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
// Add multiple deny rules for different ports
|
// Add multiple deny rules for different ports
|
||||||
rule1, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
rule2, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
// Delete the first deny rule
|
// Delete the first deny rule
|
||||||
err = m.DeleteFilterRule(rule1)
|
err = m.DeletePeerRule(rule1[0])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCount = countRulesForAddr(m.incomingDenyRules, addr)
|
denyCount = len(m.incomingDenyRules[addr])
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
// Delete the second deny rule
|
// Delete the second deny rule
|
||||||
err = m.DeleteFilterRule(rule2)
|
err = m.DeletePeerRule(rule2[0])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
exists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
_, exists := m.incomingDenyRules[addr]
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
require.False(t, exists, "Deny rules should be cleaned up when empty")
|
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
@@ -274,7 +299,7 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, m.Close(nil))
|
require.NoError(t, m.Close(nil))
|
||||||
@@ -286,21 +311,27 @@ func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
|||||||
// Simulate 10 network map updates: add rule, delete old, add new
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
// Add a deny rule
|
// Add a deny rule
|
||||||
rules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add an allow rule
|
// Add an allow rule
|
||||||
allowRules, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Delete them (simulating ACL manager cleanup)
|
// Delete them (simulating ACL manager cleanup)
|
||||||
require.NoError(t, m.DeleteFilterRule(rules))
|
for _, r := range rules {
|
||||||
require.NoError(t, m.DeleteFilterRule(allowRules))
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
for _, r := range allowRules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
allowCount := len(m.incomingRules[addr])
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
@@ -314,7 +345,7 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, m.Close(nil))
|
require.NoError(t, m.Close(nil))
|
||||||
@@ -323,39 +354,41 @@ func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
|||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
// Add allow rule for port 80
|
// Add allow rule for port 80
|
||||||
allowRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Add deny rule for port 22
|
// Add deny rule for port 22
|
||||||
denyRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{22}}, fw.ActionDrop)
|
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
addr := netip.MustParseAddr("192.168.1.1")
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
allowCount := countRulesForAddr(m.incomingAcceptRules, addr)
|
allowCount := len(m.incomingRules[addr])
|
||||||
denyCount := countRulesForAddr(m.incomingDenyRules, addr)
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
// Delete allow rule should not affect deny rule
|
// Delete allow rule should not affect deny rule
|
||||||
err = m.DeleteFilterRule(allowRule)
|
err = m.DeletePeerRule(allowRule[0])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyCountAfter := countRulesForAddr(m.incomingDenyRules, addr)
|
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
// Delete deny rule
|
// Delete deny rule
|
||||||
err = m.DeleteFilterRule(denyRule)
|
err = m.DeletePeerRule(denyRule[0])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
denyExists := countRulesForAddr(m.incomingDenyRules, addr) > 0
|
_, denyExists := m.incomingDenyRules[addr]
|
||||||
allowExists := countRulesForAddr(m.incomingAcceptRules, addr) > 0
|
_, allowExists := m.incomingRules[addr]
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
require.False(t, denyExists, "Deny rules should be empty")
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
@@ -367,7 +400,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -378,7 +411,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
|
|
||||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -390,7 +423,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(m.incomingAcceptRules) != 0 || len(m.incomingDenyRules) != 0 {
|
if len(m.outgoingRules) != 0 || len(m.incomingRules) != 0 || len(m.incomingDenyRules) != 0 {
|
||||||
t.Errorf("rules are not empty")
|
t.Errorf("rules are not empty")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -406,7 +439,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -416,7 +449,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
|
|
||||||
_, err = m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -469,7 +502,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(Config{IFace: iface, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -486,11 +519,9 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
@@ -575,7 +606,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -590,7 +621,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddFilterRule(nil, pfx(ip), fw.Network{}, "tcp", nil, port, fw.ActionAccept)
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -600,11 +631,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
@@ -816,7 +845,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -829,7 +858,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Set: set},
|
fw.Network{Set: set},
|
||||||
@@ -902,7 +931,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -910,7 +939,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
|||||||
|
|
||||||
set := fw.NewDomainSet(domain.List{"example.org"})
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
rule, err := manager.AddFilterRule(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
fw.Network{Set: set},
|
fw.Network{Set: set},
|
||||||
@@ -1022,7 +1051,7 @@ func TestMSSClamping(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: 1280})
|
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -1214,7 +1243,7 @@ func TestShouldForward(t *testing.T) {
|
|||||||
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)}
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -1329,7 +1358,7 @@ func TestShouldForward(t *testing.T) {
|
|||||||
|
|
||||||
// Re-create manager to pick up the new address with IPv6
|
// Re-create manager to pick up the new address with IPv6
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
manager, err = Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
manager, err = Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
v6Cases := []struct {
|
v6Cases := []struct {
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
@@ -21,9 +20,9 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -34,12 +33,6 @@ const (
|
|||||||
iosMaxInFlight = 256
|
iosMaxInFlight = 256
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFace provides the WireGuard device and overlay addresses the forwarder needs.
|
|
||||||
type IFace interface {
|
|
||||||
GetWGDevice() *wgdevice.Device
|
|
||||||
Address() wgaddr.Address
|
|
||||||
}
|
|
||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
@@ -58,7 +51,7 @@ type Forwarder struct {
|
|||||||
pingSemaphore chan struct{}
|
pingSemaphore chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface IFace, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||||
ipv4.NewProtocol,
|
ipv4.NewProtocol,
|
||||||
|
|||||||
@@ -362,10 +362,6 @@ func (f *Forwarder) injectICMPv6Reply(id stack.TransportEndpointID, icmpPayload
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if pc := f.endpoint.capture.Load(); pc != nil {
|
|
||||||
(*pc).Offer(fullPacket, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(fullPacket)
|
return len(fullPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
// localIPSnapshot is an immutable snapshot of local IP addresses, swapped
|
||||||
@@ -58,7 +60,7 @@ func processInterface(iface net.Interface, ips map[netip.Addr]struct{}, addresse
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
// UpdateLocalIPs rebuilds the local IP snapshot and swaps it in atomically.
|
||||||
func (m *localIPManager) UpdateLocalIPs(iface Iface) (err error) {
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
err = fmt.Errorf("panic: %v", r)
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
|||||||
@@ -487,13 +487,19 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network.
|
||||||
func (m *Manager) AddDNATRule(firewall.ForwardRule) (firewall.Rule, error) {
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
return nil, errNotSupported
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDNATRule deletes outbound DNAT rule.
|
// DeleteDNATRule deletes outbound DNAT rule.
|
||||||
func (m *Manager) DeleteDNATRule(firewall.Rule) error {
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
return errNotSupported
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addPortRedirection adds a port redirection rule.
|
// addPortRedirection adds a port redirection rule.
|
||||||
@@ -515,6 +521,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
|
||||||
|
// TODO: also delegate to nativeFirewall when available for kernel WG mode
|
||||||
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
var layerType gopacket.LayerType
|
var layerType gopacket.LayerType
|
||||||
switch protocol {
|
switch protocol {
|
||||||
@@ -560,16 +567,20 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
|
|||||||
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
|
return m.removePortRedirection(localAddr, layerType, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddOutputDNAT is not supported by the userspace firewall: it backs kernel DNS
|
// AddOutputDNAT delegates to the native firewall if available.
|
||||||
// redirection, but userspace DNS is served in-process on the gVisor netstack, so
|
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
// this should never be called.
|
if m.nativeFirewall == nil {
|
||||||
func (m *Manager) AddOutputDNAT(netip.Addr, firewall.Protocol, uint16, uint16) error {
|
return fmt.Errorf("output DNAT not supported without native firewall")
|
||||||
return errNotSupported
|
}
|
||||||
|
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveOutputDNAT is a no-op for the userspace firewall (see AddOutputDNAT).
|
// RemoveOutputDNAT delegates to the native firewall if available.
|
||||||
func (m *Manager) RemoveOutputDNAT(netip.Addr, firewall.Protocol, uint16, uint16) error {
|
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, originalPort, translatedPort uint16) error {
|
||||||
return nil
|
if m.nativeFirewall == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, originalPort, translatedPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
|
||||||
|
|||||||
@@ -64,11 +64,9 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -126,11 +124,9 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
|||||||
|
|
||||||
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -200,11 +196,9 @@ func BenchmarkDNATScaling(b *testing.B) {
|
|||||||
|
|
||||||
for _, count := range mappingCounts {
|
for _, count := range mappingCounts {
|
||||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -314,11 +308,9 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
|||||||
|
|
||||||
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -489,11 +481,9 @@ func BenchmarkPortDNAT(b *testing.B) {
|
|||||||
|
|
||||||
for _, sc := range scenarios {
|
for _, sc := range scenarios {
|
||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
|
|||||||
@@ -13,11 +13,9 @@ import (
|
|||||||
|
|
||||||
// TestPortDNATBasic tests basic port DNAT functionality
|
// TestPortDNATBasic tests basic port DNAT functionality
|
||||||
func TestPortDNATBasic(t *testing.T) {
|
func TestPortDNATBasic(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -51,11 +49,9 @@ func TestPortDNATBasic(t *testing.T) {
|
|||||||
|
|
||||||
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
// TestPortDNATMultipleRules tests multiple port DNAT rules
|
||||||
func TestPortDNATMultipleRules(t *testing.T) {
|
func TestPortDNATMultipleRules(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
|
|||||||
@@ -15,11 +15,9 @@ import (
|
|||||||
|
|
||||||
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -106,11 +104,9 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
|||||||
|
|
||||||
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||||
func TestDNATMappingManagement(t *testing.T) {
|
func TestDNATMappingManagement(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -156,11 +152,9 @@ func TestDNATMappingManagement(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInboundPortDNAT(t *testing.T) {
|
func TestInboundPortDNAT(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -208,11 +202,9 @@ func TestInboundPortDNAT(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInboundPortDNATNegative(t *testing.T) {
|
func TestInboundPortDNATNegative(t *testing.T) {
|
||||||
manager, err := Create(Config{
|
manager, err := Create(&IFaceMock{
|
||||||
IFace: &IFaceMock{
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
|
|||||||
@@ -1,333 +0,0 @@
|
|||||||
//go:build uspbench
|
|
||||||
|
|
||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
// BenchmarkPeerACLMatch measures the per-packet cost of the peer ACL
|
|
||||||
// matcher (peerACLsBlock) across realistic shapes: M distinct policy
|
|
||||||
// rules, each with K source peers in its set.
|
|
||||||
//
|
|
||||||
// With the reverse-source index, miss cost is independent of M and
|
|
||||||
// hit cost grows only with the number of rules touching a single
|
|
||||||
// srcIP, not with total rule count.
|
|
||||||
func BenchmarkPeerACLMatch(b *testing.B) {
|
|
||||||
shapes := []struct{ M, K int }{
|
|
||||||
{1, 100}, {10, 100}, {50, 100}, {100, 100}, {100, 1000},
|
|
||||||
}
|
|
||||||
families := []struct {
|
|
||||||
name string
|
|
||||||
v6 bool
|
|
||||||
}{{"v4", false}, {"v6", true}}
|
|
||||||
|
|
||||||
for _, fam := range families {
|
|
||||||
for _, s := range shapes {
|
|
||||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/hit", fam.name, s.M, s.K), func(b *testing.B) {
|
|
||||||
runPeerACLBench(b, s.M, s.K, true, fam.v6)
|
|
||||||
})
|
|
||||||
b.Run(fmt.Sprintf("%s/M=%d/K=%d/miss", fam.name, s.M, s.K), func(b *testing.B) {
|
|
||||||
runPeerACLBench(b, s.M, s.K, false, fam.v6)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func runPeerACLBench(b *testing.B, m, k int, hit, v6 bool) {
|
|
||||||
log.SetOutput(io.Discard) // keep manager logs out of the benchmark output
|
|
||||||
|
|
||||||
// Miss packets are dropped, so they always traverse the full peer
|
|
||||||
// ACL matcher (every bucket) without short-circuiting and without
|
|
||||||
// touching conntrack. Disable conntrack for the miss case so it
|
|
||||||
// measures the matcher, not established-state lookups. The hit case
|
|
||||||
// keeps conntrack on: an accepted packet reaches trackInbound, which
|
|
||||||
// needs the trackers conntrack creates.
|
|
||||||
if !hit {
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
|
||||||
}
|
|
||||||
|
|
||||||
bits := 32
|
|
||||||
genPkt := generatePacket
|
|
||||||
addrs := uniqueAddrs
|
|
||||||
if v6 {
|
|
||||||
bits = 128
|
|
||||||
genPkt = generatePacket6
|
|
||||||
addrs = uniqueAddrs6
|
|
||||||
}
|
|
||||||
|
|
||||||
// dstIP must be a local IP so filterInbound takes the local-traffic
|
|
||||||
// path (handleLocalTraffic → peerACLsBlock) we are measuring; an
|
|
||||||
// address the manager doesn't own would be treated as routed and
|
|
||||||
// short-circuit before the peer matcher.
|
|
||||||
dstIP := addrs(1, 2)[0]
|
|
||||||
mockAddr := wgaddr.Address{IP: dstIP, Network: netip.PrefixFrom(dstIP, bits)}
|
|
||||||
if v6 {
|
|
||||||
// The local-IP manager needs a valid v4 address too; expose the v6
|
|
||||||
// dst as the interface's IPv6 so IsLocalIP recognizes it.
|
|
||||||
mockAddr = wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.64.0.1"),
|
|
||||||
Network: netip.MustParsePrefix("100.64.0.0/16"),
|
|
||||||
IPv6: dstIP,
|
|
||||||
IPv6Net: netip.PrefixFrom(dstIP, bits),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
manager, err := Create(Config{
|
|
||||||
IFace: &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address { return mockAddr },
|
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
b.Cleanup(func() { require.NoError(b, manager.Close(nil)) })
|
|
||||||
|
|
||||||
// Generate M policies × K source peers, all distinct.
|
|
||||||
all := addrs(m*k, 1)
|
|
||||||
for i := 0; i < m; i++ {
|
|
||||||
sources := make([]netip.Prefix, k)
|
|
||||||
for j, a := range all[i*k : (i+1)*k] {
|
|
||||||
sources[j] = netip.PrefixFrom(a, bits)
|
|
||||||
}
|
|
||||||
_, err := manager.AddFilterRule(
|
|
||||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
|
||||||
fw.ActionAccept)
|
|
||||||
require.NoError(b, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hit: cycle through real sources, picking the matching policy's port.
|
|
||||||
// Miss: a source from a disjoint range, port 80 (matches no policy).
|
|
||||||
var pktFn func(i int) []byte
|
|
||||||
if hit {
|
|
||||||
pktFn = func(i int) []byte {
|
|
||||||
policy := i % m
|
|
||||||
src := all[policy*k+(i%k)]
|
|
||||||
return genPkt(b, src.AsSlice(), dstIP.AsSlice(),
|
|
||||||
uint16(1024+i%60000), uint16(80+policy), layers.IPProtocolTCP)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
miss := addrs(4096, 99)
|
|
||||||
pktFn = func(i int) []byte {
|
|
||||||
return genPkt(b, miss[i%len(miss)].AsSlice(), dstIP.AsSlice(),
|
|
||||||
uint16(1024+i%60000), 80, layers.IPProtocolTCP)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-build a pool to avoid allocations dominating the measurement.
|
|
||||||
pool := make([][]byte, 1024)
|
|
||||||
for i := range pool {
|
|
||||||
pool[i] = pktFn(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Confirm the matcher is actually exercised: a hit packet must be
|
|
||||||
// allowed and a miss packet dropped. Without this the benchmark
|
|
||||||
// could silently time the routed early-return instead.
|
|
||||||
require.Equal(b, !hit, manager.filterInbound(pool[0], 0),
|
|
||||||
"benchmark must reach the peer ACL matcher")
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
manager.filterInbound(pool[i%len(pool)], 0)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BenchmarkPeerACLIndexMemory reports the resident memory cost of
|
|
||||||
// the source-keyed index across realistic deployment shapes. Two
|
|
||||||
// dimensions matter: (M, K), the number of policies × peers-per-policy,
|
|
||||||
// and overlap, the fraction of peers shared between policies.
|
|
||||||
//
|
|
||||||
// The output uses ReportMetric("bytes/rule") so the cost can be
|
|
||||||
// compared across shapes directly. Total bytes = bytes/rule * M.
|
|
||||||
func BenchmarkPeerACLIndexMemory(b *testing.B) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
M, K int
|
|
||||||
overlapFrac float64 // 0 = disjoint per-policy sources, 1 = all share the same pool
|
|
||||||
}{
|
|
||||||
{"M=10/K=100/disjoint", 10, 100, 0},
|
|
||||||
{"M=100/K=100/disjoint", 100, 100, 0},
|
|
||||||
{"M=100/K=1000/disjoint", 100, 1000, 0},
|
|
||||||
{"M=100/K=1000/overlap=0.5", 100, 1000, 0.5},
|
|
||||||
{"M=100/K=1000/overlap=1.0", 100, 1000, 1.0},
|
|
||||||
{"M=1000/K=100/overlap=1.0", 1000, 100, 1.0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range cases {
|
|
||||||
b.Run(c.name, func(b *testing.B) {
|
|
||||||
b.ReportAllocs()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
mgr, err := Create(Config{
|
|
||||||
IFace: &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
},
|
|
||||||
FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
|
||||||
require.NoError(b, err)
|
|
||||||
|
|
||||||
populateIndexedRules(b, mgr, c.M, c.K, c.overlapFrac)
|
|
||||||
|
|
||||||
runtime.GC()
|
|
||||||
var ms runtime.MemStats
|
|
||||||
runtime.ReadMemStats(&ms)
|
|
||||||
before := ms.HeapAlloc
|
|
||||||
|
|
||||||
// Drop the manager's external roots so we can isolate
|
|
||||||
// the index cost. We hold the manager itself live; the
|
|
||||||
// index is what we measure on the second pass.
|
|
||||||
mgr.incomingAcceptIndex.reset()
|
|
||||||
mgr.incomingDenyIndex.reset()
|
|
||||||
mgr.incomingAcceptRules = mgr.incomingAcceptRules[:0]
|
|
||||||
mgr.incomingDenyRules = mgr.incomingDenyRules[:0]
|
|
||||||
runtime.GC()
|
|
||||||
runtime.ReadMemStats(&ms)
|
|
||||||
after := ms.HeapAlloc
|
|
||||||
|
|
||||||
delta := int64(before) - int64(after)
|
|
||||||
if delta < 0 {
|
|
||||||
delta = 0
|
|
||||||
}
|
|
||||||
b.ReportMetric(float64(delta)/float64(c.M), "bytes/rule")
|
|
||||||
b.ReportMetric(float64(delta), "bytes/total")
|
|
||||||
|
|
||||||
require.NoError(b, mgr.Close(nil))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func populateIndexedRules(b *testing.B, mgr *Manager, m, k int, overlapFrac float64) {
|
|
||||||
b.Helper()
|
|
||||||
pool := uniqueAddrs(k+m*k, 1) // big enough universe
|
|
||||||
sharedLen := int(float64(k) * overlapFrac)
|
|
||||||
if sharedLen > k {
|
|
||||||
sharedLen = k
|
|
||||||
}
|
|
||||||
shared := pool[:sharedLen]
|
|
||||||
uniquePool := pool[sharedLen:]
|
|
||||||
|
|
||||||
for i := 0; i < m; i++ {
|
|
||||||
sources := make([]netip.Prefix, 0, k)
|
|
||||||
for _, a := range shared {
|
|
||||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
|
||||||
}
|
|
||||||
// each policy gets (k-sharedLen) addresses unique to it from the unique pool
|
|
||||||
unique := uniquePool[i*(k-sharedLen) : (i+1)*(k-sharedLen)]
|
|
||||||
for _, a := range unique {
|
|
||||||
sources = append(sources, netip.PrefixFrom(a, 32))
|
|
||||||
}
|
|
||||||
_, err := mgr.AddFilterRule(
|
|
||||||
nil, sources, fw.Network{}, fw.ProtocolTCP, nil,
|
|
||||||
&fw.Port{Values: []uint16{uint16(80 + i)}},
|
|
||||||
fw.ActionAccept)
|
|
||||||
require.NoError(b, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// uniqueAddrs returns n distinct addrs. Seeds 1, 2 are used for
|
|
||||||
// policy sources / dst; seed 99 puts misses in 10/8.
|
|
||||||
func uniqueAddrs(n int, seed int64) []netip.Addr {
|
|
||||||
out := make([]netip.Addr, 0, n)
|
|
||||||
seen := make(map[netip.Addr]struct{}, n)
|
|
||||||
r := rand.New(rand.NewSource(seed))
|
|
||||||
miss := seed == 99
|
|
||||||
for len(out) < n {
|
|
||||||
var b [4]byte
|
|
||||||
if miss {
|
|
||||||
b[0] = 10
|
|
||||||
b[1] = byte(r.Intn(256))
|
|
||||||
} else {
|
|
||||||
b[0] = 100
|
|
||||||
b[1] = byte(64 + r.Intn(63))
|
|
||||||
}
|
|
||||||
b[2] = byte(r.Intn(256))
|
|
||||||
b[3] = byte(1 + r.Intn(254))
|
|
||||||
a := netip.AddrFrom4(b)
|
|
||||||
if _, ok := seen[a]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[a] = struct{}{}
|
|
||||||
out = append(out, a)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// uniqueAddrs6 mirrors uniqueAddrs for IPv6: sources come from the ULA
|
|
||||||
// range fd00::/8, the miss set (seed 99) from 2001:db8::/32 so it is
|
|
||||||
// disjoint from any source.
|
|
||||||
func uniqueAddrs6(n int, seed int64) []netip.Addr {
|
|
||||||
out := make([]netip.Addr, 0, n)
|
|
||||||
seen := make(map[netip.Addr]struct{}, n)
|
|
||||||
r := rand.New(rand.NewSource(seed))
|
|
||||||
miss := seed == 99
|
|
||||||
for len(out) < n {
|
|
||||||
var b [16]byte
|
|
||||||
if miss {
|
|
||||||
b[0], b[1], b[2], b[3] = 0x20, 0x01, 0x0d, 0xb8
|
|
||||||
} else {
|
|
||||||
b[0] = 0xfd
|
|
||||||
}
|
|
||||||
for x := 8; x < 16; x++ {
|
|
||||||
b[x] = byte(r.Intn(256))
|
|
||||||
}
|
|
||||||
a := netip.AddrFrom16(b)
|
|
||||||
if _, ok := seen[a]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[a] = struct{}{}
|
|
||||||
out = append(out, a)
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// generatePacket6 builds an IPv6 TCP/UDP packet, mirroring
|
|
||||||
// generatePacket for the v4 case.
|
|
||||||
func generatePacket6(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
|
|
||||||
b.Helper()
|
|
||||||
|
|
||||||
ipv6 := &layers.IPv6{
|
|
||||||
Version: 6,
|
|
||||||
HopLimit: 64,
|
|
||||||
NextHeader: protocol,
|
|
||||||
SrcIP: srcIP,
|
|
||||||
DstIP: dstIP,
|
|
||||||
}
|
|
||||||
|
|
||||||
var transportLayer gopacket.SerializableLayer
|
|
||||||
switch protocol {
|
|
||||||
case layers.IPProtocolTCP:
|
|
||||||
tcp := &layers.TCP{
|
|
||||||
SrcPort: layers.TCPPort(srcPort),
|
|
||||||
DstPort: layers.TCPPort(dstPort),
|
|
||||||
SYN: true,
|
|
||||||
}
|
|
||||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv6))
|
|
||||||
transportLayer = tcp
|
|
||||||
case layers.IPProtocolUDP:
|
|
||||||
udp := &layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(srcPort),
|
|
||||||
DstPort: layers.UDPPort(dstPort),
|
|
||||||
}
|
|
||||||
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv6))
|
|
||||||
transportLayer = udp
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
|
||||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv6, transportLayer, gopacket.Payload("test")))
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
@@ -1,150 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestManager(t *testing.T) *Manager {
|
|
||||||
t.Helper()
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
}
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
|
||||||
require.NoError(t, err, "create manager")
|
|
||||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddPeerFiltering_DeduplicatesIdenticalRules verifies that adding
|
|
||||||
// the same peer rule twice does not create two backing rules. The acl
|
|
||||||
// manager keys its own cache, but the firewall backend must be
|
|
||||||
// idempotent on its own so a double-apply cannot leak rules, matching
|
|
||||||
// the route path and the kernel backends.
|
|
||||||
func TestAddPeerFiltering_DeduplicatesIdenticalRules(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
|
||||||
action := fw.ActionDrop
|
|
||||||
|
|
||||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err, "first add")
|
|
||||||
|
|
||||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err, "second add")
|
|
||||||
|
|
||||||
assert.Equal(t, first.ID(), second.ID(), "duplicate add should return the same rule id")
|
|
||||||
assert.Len(t, m.incomingDenyRules, 1, "duplicate add must not create a second backing rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves locks the
|
|
||||||
// backend's owner accounting for the same-owner case: a content key
|
|
||||||
// installed twice by the same owner registers one owner claim, so the
|
|
||||||
// first DeleteFilterRule removes the rule. Owner counting only kicks
|
|
||||||
// in for distinct management rule IDs (see the peer owner tests); the
|
|
||||||
// acl manager keys its tracking per (policy, content) and deletes once
|
|
||||||
// per key, so adds and deletes stay balanced.
|
|
||||||
func TestDeletePeerFiltering_NoRefcountSingleDeleteRemoves(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
|
||||||
action := fw.ActionDrop
|
|
||||||
|
|
||||||
first, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err, "first add")
|
|
||||||
|
|
||||||
second, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err, "second add")
|
|
||||||
require.Equal(t, first.ID(), second.ID(), "dedup to one rule")
|
|
||||||
require.Len(t, m.incomingDenyRules, 1, "still one backing rule after duplicate add")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteFilterRule(first), "delete once")
|
|
||||||
assert.Empty(t, m.incomingDenyRules, "single delete removes the backing rule (no refcount)")
|
|
||||||
assert.NotContains(t, m.peerRulesMap, first.ID(), "dedup map entry cleared")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddPeerFiltering_DeterministicID verifies the peer rule id is a
|
|
||||||
// content hash, not a random UUID: identical inputs produce the same id
|
|
||||||
// across independent managers. A random id breaks caller-side dedup.
|
|
||||||
func TestAddPeerFiltering_DeterministicID(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.0.0.5")
|
|
||||||
proto := fw.ProtocolUDP
|
|
||||||
port := &fw.Port{Values: []uint16{53}}
|
|
||||||
action := fw.ActionAccept
|
|
||||||
|
|
||||||
m1 := newTestManager(t)
|
|
||||||
r1, err := m1.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
m2 := newTestManager(t)
|
|
||||||
r2, err := m2.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, r1.ID(), r2.ID(), "same inputs must produce the same rule id")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddPeerFiltering_DistinctRulesNotDeduped verifies that rules
|
|
||||||
// differing only by port are kept separate.
|
|
||||||
func TestAddPeerFiltering_DistinctRulesNotDeduped(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
action := fw.ActionAccept
|
|
||||||
|
|
||||||
r80, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{80}}, action)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
r443, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, &fw.Port{Values: []uint16{443}}, action)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NotEqual(t, r80.ID(), r443.ID(), "different ports must produce different rule ids")
|
|
||||||
assert.Len(t, m.incomingAcceptRules, 2, "distinct rules must both be stored")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddPeerFiltering_SourceVsDestPortNotDeduped verifies that a rule
|
|
||||||
// matching on source port and one matching on destination port for the
|
|
||||||
// same selector do not collide: the port lands in a different slot, so
|
|
||||||
// the content key must differ.
|
|
||||||
func TestAddPeerFiltering_SourceVsDestPortNotDeduped(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
|
||||||
action := fw.ActionAccept
|
|
||||||
|
|
||||||
dPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
sPortRule, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, port, nil, action)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NotEqual(t, dPortRule.ID(), sPortRule.ID(), "source-port and dest-port matches must produce different rule ids")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestAddFilterRule_EmptySourcesRejected verifies that an empty source
|
|
||||||
// list is rejected rather than treated as "match any". "Match any" must
|
|
||||||
// be an explicit /0, so a zeroed list can never silently widen a rule to
|
|
||||||
// every source.
|
|
||||||
func TestAddFilterRule_EmptySourcesRejected(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
proto := fw.ProtocolTCP
|
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
|
||||||
|
|
||||||
_, err := m.AddFilterRule(nil, nil, fw.Network{}, proto, nil, port, fw.ActionAccept)
|
|
||||||
require.ErrorIs(t, err, fw.ErrNoSources, "empty sources must be rejected")
|
|
||||||
assert.Empty(t, m.incomingAcceptRules, "no rule should be stored for empty sources")
|
|
||||||
}
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newV6TestManager(t *testing.T, localV6 string) *Manager {
|
|
||||||
t.Helper()
|
|
||||||
ifaceMock := &IFaceMock{
|
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
|
||||||
AddressFunc: func() wgaddr.Address {
|
|
||||||
return wgaddr.Address{
|
|
||||||
IP: netip.MustParseAddr("100.10.0.100"),
|
|
||||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
|
||||||
IPv6: netip.MustParseAddr(localV6),
|
|
||||||
IPv6Net: netip.MustParsePrefix("fd00::/64"),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: nbiface.DefaultMTU})
|
|
||||||
require.NoError(t, err, "create manager")
|
|
||||||
t.Cleanup(func() { require.NoError(t, m.Close(nil)) })
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func v6UDPPacket(t *testing.T, src, dst string, dstPort uint16) []byte {
|
|
||||||
t.Helper()
|
|
||||||
ip6 := &layers.IPv6{
|
|
||||||
Version: 6,
|
|
||||||
HopLimit: 64,
|
|
||||||
NextHeader: layers.IPProtocolUDP,
|
|
||||||
SrcIP: net.ParseIP(src),
|
|
||||||
DstIP: net.ParseIP(dst),
|
|
||||||
}
|
|
||||||
udp := &layers.UDP{SrcPort: 51334, DstPort: layers.UDPPort(dstPort)}
|
|
||||||
require.NoError(t, udp.SetNetworkLayerForChecksum(ip6))
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
|
||||||
require.NoError(t, gopacket.SerializeLayers(buf, opts, ip6, udp, gopacket.Payload("test")))
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerACL_IPv6HostRule verifies the source index resolves /128 v6
|
|
||||||
// rules: a matching v6 source is accepted, a non-matching one is
|
|
||||||
// denied by the default. This is the end-to-end proof that the index
|
|
||||||
// is not v4-only.
|
|
||||||
func TestPeerACL_IPv6HostRule(t *testing.T) {
|
|
||||||
m := newV6TestManager(t, "fd00::100")
|
|
||||||
|
|
||||||
src := net.ParseIP("fd00::1")
|
|
||||||
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
|
||||||
require.NoError(t, err, "add v6 accept rule")
|
|
||||||
|
|
||||||
require.False(t, m.filterInbound(v6UDPPacket(t, "fd00::1", "fd00::100", 53), 0),
|
|
||||||
"v6 packet from the allowed /128 source must be accepted")
|
|
||||||
require.True(t, m.filterInbound(v6UDPPacket(t, "fd00::2", "fd00::100", 53), 0),
|
|
||||||
"v6 packet from an unlisted source must be denied by default")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerACL_IPv6IndexBuckets verifies that v6 sources land in the
|
|
||||||
// right index bucket: a /128 in bySource keyed by its address, and
|
|
||||||
// coarser prefixes (including ::/0) in the nonHost slice.
|
|
||||||
func TestPeerACL_IPv6IndexBuckets(t *testing.T) {
|
|
||||||
m := newV6TestManager(t, "fd00::100")
|
|
||||||
port := &fw.Port{Values: []uint16{53}}
|
|
||||||
|
|
||||||
host := netip.MustParseAddr("fd00::1")
|
|
||||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(host, 128)}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Contains(t, m.incomingAcceptIndex.bySource, host, "/128 v6 source must be indexed by address")
|
|
||||||
|
|
||||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("fd00:dead::/64")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, m.incomingAcceptIndex.nonHost, 1, "coarser v6 prefix must land in nonHost")
|
|
||||||
|
|
||||||
_, err = m.AddFilterRule(nil, []netip.Prefix{netip.MustParsePrefix("::/0")}, fw.Network{}, fw.ProtocolUDP, nil, port, fw.ActionAccept)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, m.incomingAcceptIndex.nonHost, 2, "::/0 source must also land in nonHost")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerACL_IPv4MappedSourceNormalized verifies a v4-mapped v6
|
|
||||||
// source prefix is normalized to v4 so a plain v4 packet matches it.
|
|
||||||
func TestPeerACL_IPv4MappedSourceNormalized(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
mapped := netip.MustParseAddr("::ffff:192.168.1.1")
|
|
||||||
_, err := m.AddFilterRule(nil, []netip.Prefix{netip.PrefixFrom(mapped, mapped.BitLen())}, fw.Network{}, fw.ProtocolUDP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionAccept)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
v4 := netip.MustParseAddr("192.168.1.1")
|
|
||||||
assert.Contains(t, m.incomingAcceptIndex.bySource, v4, "v4-mapped v6 source must be indexed as plain v4")
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// peerACLCheck decodes the packet and runs it through the peer ACLs,
|
|
||||||
// returning the attributed management rule id and the drop verdict.
|
|
||||||
func peerACLCheck(t *testing.T, m *Manager, packet []byte) ([]byte, bool) {
|
|
||||||
t.Helper()
|
|
||||||
d := m.decoders.Get().(*decoder)
|
|
||||||
defer m.decoders.Put(d)
|
|
||||||
require.NoError(t, d.decodePacket(packet))
|
|
||||||
src, _ := m.extractIPs(d)
|
|
||||||
return m.peerACLsBlock(src, d, packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerACL_MultiValuePortMatchesEachListedPort guards the multi-value
|
|
||||||
// port path: a rule listing several discrete destination ports must
|
|
||||||
// match a packet to each listed port and drop one that is not listed.
|
|
||||||
// Management currently splits a multi-port policy into one rule per port
|
|
||||||
// (and the wire format carries a single port), so this list shape is not
|
|
||||||
// emitted today; the test locks correct matching in case that changes.
|
|
||||||
func TestPeerACL_MultiValuePortMatchesEachListedPort(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
src := net.ParseIP("192.168.1.1")
|
|
||||||
ports := &fw.Port{Values: []uint16{80, 443}}
|
|
||||||
_, err := m.AddFilterRule(nil, pfx(src), fw.Network{}, fw.ProtocolTCP, nil, ports, fw.ActionAccept)
|
|
||||||
require.NoError(t, err, "add multi-value port rule")
|
|
||||||
|
|
||||||
for _, p := range []uint16{80, 443} {
|
|
||||||
_, blocked := peerACLCheck(t, m, createTestPacket(t, "192.168.1.1", "10.0.0.2", fw.ProtocolTCP, 12345, p))
|
|
||||||
assert.False(t, blocked, "packet to listed port %d must match the rule", p)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, blocked := peerACLCheck(t, m, createTestPacket(t, "192.168.1.1", "10.0.0.2", fw.ProtocolTCP, 12345, 8080))
|
|
||||||
assert.True(t, blocked, "packet to a port not in the list must not match the rule")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestPeerACL_MatchAnyIsFamilyScoped verifies that a /0 source matches
|
|
||||||
// only packets of its own family: 0.0.0.0/0 must not match IPv6 packets
|
|
||||||
// and ::/0 must not match IPv4 packets, matching kernel backend
|
|
||||||
// semantics.
|
|
||||||
func TestPeerACL_MatchAnyIsFamilyScoped(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
v4Packet := createTestPacket(t, "10.0.0.1", "10.0.0.2", fw.ProtocolUDP, 12345, 53)
|
|
||||||
v6Packet := v6UDPPacket(t, "fd00::1", "fd00::100", 53)
|
|
||||||
|
|
||||||
v4Any := []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
rule, err := m.AddFilterRule(nil, v4Any, fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
|
||||||
require.NoError(t, err, "add v4 /0 rule")
|
|
||||||
|
|
||||||
_, blocked := peerACLCheck(t, m, v4Packet)
|
|
||||||
assert.False(t, blocked, "0.0.0.0/0 must match IPv4 packets")
|
|
||||||
_, blocked = peerACLCheck(t, m, v6Packet)
|
|
||||||
assert.True(t, blocked, "0.0.0.0/0 must not match IPv6 packets")
|
|
||||||
|
|
||||||
require.NoError(t, m.DeleteFilterRule(rule))
|
|
||||||
|
|
||||||
v6Any := []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
_, err = m.AddFilterRule(nil, v6Any, fw.Network{}, fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
|
||||||
require.NoError(t, err, "add v6 /0 rule")
|
|
||||||
|
|
||||||
_, blocked = peerACLCheck(t, m, v6Packet)
|
|
||||||
assert.False(t, blocked, "::/0 must match IPv6 packets")
|
|
||||||
_, blocked = peerACLCheck(t, m, v4Packet)
|
|
||||||
assert.True(t, blocked, "::/0 must not match IPv4 packets")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestRouteACL_MixedFamilyZeroSourcesStayFamilySafe verifies the route
|
|
||||||
// path keeps per-prefix family matching when a single rule carries both
|
|
||||||
// 0.0.0.0/0 and ::/0 sources, as blockInvalidRouted does.
|
|
||||||
func TestRouteACL_MixedFamilyZeroSourcesStayFamilySafe(t *testing.T) {
|
|
||||||
m := newTestManager(t)
|
|
||||||
|
|
||||||
sources := []netip.Prefix{
|
|
||||||
netip.PrefixFrom(netip.IPv4Unspecified(), 0),
|
|
||||||
netip.PrefixFrom(netip.IPv6Unspecified(), 0),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := m.AddFilterRule(nil, sources, fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
|
||||||
fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
|
||||||
require.NoError(t, err)
|
|
||||||
_, err = m.AddFilterRule(nil, sources, fw.Network{Prefix: netip.MustParsePrefix("fd00:1::/64")},
|
|
||||||
fw.ProtocolALL, nil, nil, fw.ActionAccept)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
v4Src := netip.MustParseAddr("192.168.1.1")
|
|
||||||
v6Src := netip.MustParseAddr("fd00::1")
|
|
||||||
|
|
||||||
_, pass := m.routeACLsPass(v4Src, netip.MustParseAddr("10.0.0.5"), 255, 0, 0)
|
|
||||||
assert.True(t, pass, "v4 source must match the v4 destination rule via 0.0.0.0/0")
|
|
||||||
_, pass = m.routeACLsPass(v6Src, netip.MustParseAddr("fd00:1::5"), 255, 0, 0)
|
|
||||||
assert.True(t, pass, "v6 source must match the v6 destination rule via ::/0")
|
|
||||||
_, pass = m.routeACLsPass(v6Src, netip.MustParseAddr("10.0.0.5"), 255, 0, 0)
|
|
||||||
assert.True(t, pass, "v6 source still passes the v4 destination rule via ::/0 in the same source list")
|
|
||||||
}
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"slices"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// peerRuleIndex is the source-side dispatcher consulted on the packet
|
|
||||||
// hot path. It splits rules into two buckets by the shape of their
|
|
||||||
// source list:
|
|
||||||
//
|
|
||||||
// - bySource: every source is a host prefix (/32 for v4, /128 for
|
|
||||||
// v6). Keyed by the concrete source address, so a hit guarantees
|
|
||||||
// the source filter passes and the matcher goes straight to
|
|
||||||
// proto/port checks. This is the common case for peer ACLs.
|
|
||||||
// - nonHost: any source list with a prefix coarser than a host,
|
|
||||||
// including a /0 "match any". Walked linearly with a per-rule
|
|
||||||
// Contains() check. Expected small or empty for typical peer ACLs.
|
|
||||||
//
|
|
||||||
// Maintained incrementally by add/remove, never rebuilt.
|
|
||||||
type peerRuleIndex struct {
|
|
||||||
bySource map[netip.Addr][]*PeerRule
|
|
||||||
nonHost []*PeerRule
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *peerRuleIndex) add(r *PeerRule) {
|
|
||||||
if hasNonHostSource(r) {
|
|
||||||
i.nonHost = append(i.nonHost, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if i.bySource == nil {
|
|
||||||
i.bySource = make(map[netip.Addr][]*PeerRule)
|
|
||||||
}
|
|
||||||
for a := range r.sourceAddrs {
|
|
||||||
i.bySource[a] = append(i.bySource[a], r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *peerRuleIndex) remove(r *PeerRule) {
|
|
||||||
if hasNonHostSource(r) {
|
|
||||||
i.nonHost = slices.DeleteFunc(i.nonHost, eqRule(r))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if i.bySource == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for a := range r.sourceAddrs {
|
|
||||||
entries := slices.DeleteFunc(i.bySource[a], eqRule(r))
|
|
||||||
if len(entries) == 0 {
|
|
||||||
delete(i.bySource, a)
|
|
||||||
} else {
|
|
||||||
i.bySource[a] = entries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *peerRuleIndex) reset() {
|
|
||||||
i.bySource = nil
|
|
||||||
i.nonHost = i.nonHost[:0]
|
|
||||||
}
|
|
||||||
|
|
||||||
// match returns the first rule matching src and the decoded packet.
|
|
||||||
// Host rules are found by direct map lookup; nonHost rules run a
|
|
||||||
// per-rule source Contains() check. Containment is family-scoped, so
|
|
||||||
// a /0 source matches every address of its own family only (0.0.0.0/0
|
|
||||||
// never matches v6 sources and ::/0 never matches v4). Within either
|
|
||||||
// bucket the matcher runs the proto/port filter.
|
|
||||||
func (i *peerRuleIndex) match(src netip.Addr, d *decoder) ([]byte, bool, bool) {
|
|
||||||
payloadLayer := d.decoded[1]
|
|
||||||
|
|
||||||
for _, rule := range i.bySource[src] {
|
|
||||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
|
||||||
return id, drop, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, rule := range i.nonHost {
|
|
||||||
if !prefixesContain(rule.sources, src) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if id, drop, ok := matchProto(rule, d, payloadLayer); ok {
|
|
||||||
return id, drop, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func eqRule(target *PeerRule) func(*PeerRule) bool {
|
|
||||||
return func(p *PeerRule) bool { return p == target }
|
|
||||||
}
|
|
||||||
|
|
||||||
// hasNonHostSource reports whether the rule has any source prefix
|
|
||||||
// that is not a single host address. Called only at add/remove time,
|
|
||||||
// not on the packet path.
|
|
||||||
func hasNonHostSource(r *PeerRule) bool {
|
|
||||||
for _, p := range r.sources {
|
|
||||||
if p.Bits() != p.Addr().BitLen() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// matchProto applies the proto/port half of a rule against the
|
|
||||||
// decoded packet. Source matching is the caller's responsibility.
|
|
||||||
func matchProto(rule *PeerRule, d *decoder, payloadLayer gopacket.LayerType) ([]byte, bool, bool) {
|
|
||||||
drop := rule.action == firewall.ActionDrop
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
|
||||||
return rule.mgmtId, drop, true
|
|
||||||
}
|
|
||||||
if !protoLayerMatches(rule.protoLayer, payloadLayer) {
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
switch payloadLayer {
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
if portsMatch(rule.srcPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.tcp.DstPort)) {
|
|
||||||
return rule.mgmtId, drop, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
if portsMatch(rule.srcPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dstPort, uint16(d.udp.DstPort)) {
|
|
||||||
return rule.mgmtId, drop, true
|
|
||||||
}
|
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
|
||||||
return rule.mgmtId, drop, true
|
|
||||||
}
|
|
||||||
return nil, false, false
|
|
||||||
}
|
|
||||||
|
|
||||||
func prefixesContain(sources []netip.Prefix, src netip.Addr) bool {
|
|
||||||
for _, p := range sources {
|
|
||||||
if p.Contains(src) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
@@ -10,43 +10,24 @@ import (
|
|||||||
|
|
||||||
// PeerRule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id firewall.RuleID
|
id string
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
// sources is the canonical list of source prefixes this rule
|
ip netip.Addr
|
||||||
// matches against.
|
ipLayer gopacket.LayerType
|
||||||
sources []netip.Prefix
|
matchByIP bool
|
||||||
// sourceAddrs is a fast-path membership set for host-prefix
|
protoLayer gopacket.LayerType
|
||||||
// sources (/32 v4, /128 v6). Populated alongside sources;
|
sPort *firewall.Port
|
||||||
// consulted before falling back to prefix scan.
|
dPort *firewall.Port
|
||||||
sourceAddrs map[netip.Addr]struct{}
|
drop bool
|
||||||
protoLayer gopacket.LayerType
|
|
||||||
srcPort *firewall.Port
|
|
||||||
dstPort *firewall.Port
|
|
||||||
action firewall.Action
|
|
||||||
}
|
|
||||||
|
|
||||||
// matchesSource reports whether the given source address is covered
|
|
||||||
// by this rule's source list. Prefix containment is family-scoped, so
|
|
||||||
// a /0 source matches every address of its own family only.
|
|
||||||
func (r *PeerRule) matchesSource(src netip.Addr) bool {
|
|
||||||
if _, ok := r.sourceAddrs[src]; ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
for _, p := range r.sources {
|
|
||||||
if p.Contains(src) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *PeerRule) ID() firewall.RuleID {
|
func (r *PeerRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id firewall.RuleID
|
id string
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dstSet firewall.Set
|
dstSet firewall.Set
|
||||||
@@ -58,6 +39,6 @@ type RouteRule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *RouteRule) ID() firewall.RuleID {
|
func (r *RouteRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
package uspfilter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
)
|
|
||||||
|
|
||||||
// countRulesForAddr reports how many rules in the given slice match
|
|
||||||
// the supplied source address.
|
|
||||||
func countRulesForAddr(rules peerRules, src netip.Addr) int {
|
|
||||||
n := 0
|
|
||||||
for _, r := range rules {
|
|
||||||
if r.matchesSource(src) {
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
// findRuleByID returns true if the rules slice contains a rule with
|
|
||||||
// the given id whose source set covers src.
|
|
||||||
func findRuleByID(rules peerRules, src netip.Addr, id firewall.RuleID) bool {
|
|
||||||
for _, r := range rules {
|
|
||||||
if r.id == id && r.matchesSource(src) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// pfx converts a single net.IP into the []netip.Prefix form
|
|
||||||
// AddFilterRule expects. A nil or unspecified address becomes a /0
|
|
||||||
// ("match any") prefix in the matching family; any other address
|
|
||||||
// becomes its /32 (or /128) host prefix.
|
|
||||||
func pfx(ip net.IP) []netip.Prefix {
|
|
||||||
if ip == nil {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
if ip.IsUnspecified() {
|
|
||||||
if ip.To4() != nil {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
|
||||||
}
|
|
||||||
a, _ := netip.AddrFromSlice(ip)
|
|
||||||
a = a.Unmap()
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(a, a.BitLen())}
|
|
||||||
}
|
|
||||||
@@ -285,14 +285,6 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
trace.SourceIP = srcIP
|
trace.SourceIP = srcIP
|
||||||
trace.DestinationIP = dstIP
|
trace.DestinationIP = dstIP
|
||||||
|
|
||||||
// A fragment or otherwise truncated packet has no transport layer.
|
|
||||||
// The inbound datapath drops these via isValidPacket; the tracer must
|
|
||||||
// guard explicitly since every downstream stage reads d.decoded[1].
|
|
||||||
if len(d.decoded) < 2 {
|
|
||||||
trace.AddResult(StageReceived, "Packet has no transport layer (fragment or unsupported protocol)", false)
|
|
||||||
return trace
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine protocol and ports
|
// Determine protocol and ports
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(Config{IFace: ifaceMock, FlowLogger: flowLogger, MTU: iface.DefaultMTU})
|
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if !statefulMode {
|
if !statefulMode {
|
||||||
@@ -97,7 +97,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -121,7 +121,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -150,7 +150,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -178,7 +178,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -205,7 +205,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -231,7 +231,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddFilterRule(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -332,7 +332,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
ip := net.ParseIP("1.1.1.1")
|
ip := net.ParseIP("1.1.1.1")
|
||||||
proto := fw.ProtocolICMP
|
proto := fw.ProtocolICMP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -355,7 +355,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
ip := net.ParseIP("1.1.1.1")
|
ip := net.ParseIP("1.1.1.1")
|
||||||
proto := fw.ProtocolICMP
|
proto := fw.ProtocolICMP
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, nil, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -379,7 +379,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
port := &fw.Port{Values: []uint16{53}}
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
@@ -423,7 +423,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
_, err := m.AddFilterRule(nil, pfx(ip), fw.Network{}, proto, nil, port, action)
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
|||||||
@@ -1,190 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestNetworkZeroPrefixIsRoute guards the route-vs-peer dispatch
|
|
||||||
// invariant: the backends classify a rule as a peer rule purely by the
|
|
||||||
// absence of a destination (neither prefix nor set). A default route
|
|
||||||
// (0.0.0.0/0 or ::/0) is a valid prefix and must therefore classify as
|
|
||||||
// a route, not collapse into the peer path.
|
|
||||||
func TestNetworkZeroPrefixIsRoute(t *testing.T) {
|
|
||||||
for _, p := range []string{"0.0.0.0/0", "::/0", "10.0.0.0/8"} {
|
|
||||||
n := fwmgr.Network{Prefix: netip.MustParsePrefix(p)}
|
|
||||||
assert.True(t, n.IsPrefix(), "%s must report IsPrefix", p)
|
|
||||||
assert.True(t, n.IsPrefix() || n.IsSet(), "%s must classify as a route", p)
|
|
||||||
}
|
|
||||||
|
|
||||||
// A zero-value Network is the only peer-rule shape.
|
|
||||||
var empty fwmgr.Network
|
|
||||||
assert.False(t, empty.IsPrefix(), "zero Network must not be a prefix")
|
|
||||||
assert.False(t, empty.IsSet(), "zero Network must not be a set")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDetermineDestinationAlwaysRoute verifies determineDestination
|
|
||||||
// never yields an empty Network for a valid route rule: every branch
|
|
||||||
// (static prefix, default route, dynamic with/without domains, with and
|
|
||||||
// without a local resolver) produces a destination that classifies as a
|
|
||||||
// route. If this regresses, a route rule would be dispatched down the
|
|
||||||
// peer path, which matches on source only.
|
|
||||||
func TestDetermineDestinationAlwaysRoute(t *testing.T) {
|
|
||||||
v4 := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
|
|
||||||
v6 := []netip.Prefix{netip.MustParsePrefix("2001:db8::/48")}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
rule *mgmProto.RouteFirewallRule
|
|
||||||
resolver bool
|
|
||||||
sources []netip.Prefix
|
|
||||||
}{
|
|
||||||
{"static prefix", &mgmProto.RouteFirewallRule{Destination: "192.168.0.0/16"}, false, v4},
|
|
||||||
{"static default route", &mgmProto.RouteFirewallRule{Destination: "0.0.0.0/0"}, false, v4},
|
|
||||||
{"dynamic with domains + resolver", &mgmProto.RouteFirewallRule{IsDynamic: true, Domains: []string{"example.com"}}, true, v4},
|
|
||||||
{"dynamic no domains + resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v4},
|
|
||||||
{"dynamic no domains + resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, true, v6},
|
|
||||||
{"dynamic + no local resolver (v4)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v4},
|
|
||||||
{"dynamic + no local resolver (v6)", &mgmProto.RouteFirewallRule{IsDynamic: true}, false, v6},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
dest, err := determineDestination(tc.rule, tc.resolver, tc.sources)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, dest.IsPrefix() || dest.IsSet(),
|
|
||||||
"destination must classify as a route, got empty Network")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// countingFirewall wraps a real firewall.Manager and counts filter-rule
|
|
||||||
// add/delete calls so a test can assert how many backing rules the acl
|
|
||||||
// manager actually creates and tears down.
|
|
||||||
type countingFirewall struct {
|
|
||||||
fwmgr.Manager
|
|
||||||
mu sync.Mutex
|
|
||||||
addCalls int
|
|
||||||
dels int
|
|
||||||
ruleIDs map[fwmgr.RuleID]struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// distinctRules returns the number of distinct backing rules the
|
|
||||||
// backend produced. Because the backend dedups identical content,
|
|
||||||
// repeated AddFilterRule calls for the same rule resolve to one id.
|
|
||||||
func (f *countingFirewall) distinctRules() int {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return len(f.ruleIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *countingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
|
|
||||||
rule, err := f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
if err == nil {
|
|
||||||
f.mu.Lock()
|
|
||||||
f.addCalls++
|
|
||||||
if f.ruleIDs == nil {
|
|
||||||
f.ruleIDs = make(map[fwmgr.RuleID]struct{})
|
|
||||||
}
|
|
||||||
if rule != nil {
|
|
||||||
f.ruleIDs[rule.ID()] = struct{}{}
|
|
||||||
}
|
|
||||||
f.mu.Unlock()
|
|
||||||
}
|
|
||||||
return rule, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *countingFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
|
||||||
err := f.Manager.DeleteFilterRule(r)
|
|
||||||
if err == nil {
|
|
||||||
f.mu.Lock()
|
|
||||||
f.dels++
|
|
||||||
delete(f.ruleIDs, r.ID())
|
|
||||||
f.mu.Unlock()
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCountingACL(t *testing.T) (*DefaultManager, *countingFirewall, func()) {
|
|
||||||
t.Helper()
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
fw := &countingFirewall{Manager: realFW}
|
|
||||||
cleanup := func() {
|
|
||||||
require.NoError(t, realFW.Close(nil))
|
|
||||||
ctrl.Finish()
|
|
||||||
}
|
|
||||||
return NewDefaultManager(fw), fw, cleanup
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestDuplicateContentPoliciesShareOneRule verifies the dedup contract
|
|
||||||
// the backends rely on: two policies that authorize an identical flow
|
|
||||||
// (same selector and sources) collapse to a single backing firewall
|
|
||||||
// rule, and that rule survives until BOTH policies are gone. This is
|
|
||||||
// why the backend can dedup on add without refcounting on delete: the
|
|
||||||
// acl manager's pair key matches the backend's content key, so add and
|
|
||||||
// delete stay balanced per content key across full-state reapplies.
|
|
||||||
func TestDuplicateContentPoliciesShareOneRule(t *testing.T) {
|
|
||||||
acl, fw, cleanup := newCountingACL(t)
|
|
||||||
defer cleanup()
|
|
||||||
|
|
||||||
ruleA := &mgmProto.FirewallRule{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
}
|
|
||||||
ruleB := &mgmProto.FirewallRule{
|
|
||||||
PolicyID: []byte("policy-B"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Both policies present: identical content collapses to one rule.
|
|
||||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleA, ruleB}, FirewallRulesIsEmpty: false}, false)
|
|
||||||
assert.Equal(t, 1, fw.distinctRules(), "identical-content policies must produce one backing rule")
|
|
||||||
assert.Equal(t, 1, len(acl.peerRulesPairs), "one content key, one pair")
|
|
||||||
|
|
||||||
// Drop policy A only: the shared rule is still authorized by B, so
|
|
||||||
// nothing is deleted.
|
|
||||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: []*mgmProto.FirewallRule{ruleB}, FirewallRulesIsEmpty: false}, false)
|
|
||||||
assert.Equal(t, 1, fw.distinctRules(), "no new backing rule on reapply")
|
|
||||||
assert.Equal(t, 0, fw.dels, "rule must survive while any policy still authorizes it")
|
|
||||||
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
|
||||||
|
|
||||||
// Drop policy B too: now the content key has no authorizer and the
|
|
||||||
// single backing rule is removed exactly once.
|
|
||||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}, false)
|
|
||||||
assert.Equal(t, 1, fw.dels, "rule removed once when last policy is gone")
|
|
||||||
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
|
||||||
}
|
|
||||||
@@ -1,318 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/shared/netiputil"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TestGroupPeerRulesPolicyIDSeparates verifies that two FirewallRules
|
|
||||||
// with identical selectors but different PolicyIDs do NOT get merged
|
|
||||||
// into one group, so each policy's sources merge under its own
|
|
||||||
// attribution id. (Identical-content groups may still dedup to one
|
|
||||||
// backing rule at the backend; see TestDuplicateContentPoliciesShareOneRule.)
|
|
||||||
func TestGroupPeerRulesPolicyIDSeparates(t *testing.T) {
|
|
||||||
rules := []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-B"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 2, "rules with different PolicyIDs must produce separate groups")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupPeerRulesFamilySeparates verifies that v4 and v6 rules
|
|
||||||
// belonging to the same policy don't merge.
|
|
||||||
func TestGroupPeerRulesFamilySeparates(t *testing.T) {
|
|
||||||
rules := []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "2001:db8::1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 2, "rules of different families must produce separate groups")
|
|
||||||
|
|
||||||
var sawV4, sawV6 bool
|
|
||||||
for _, g := range groups {
|
|
||||||
require.Len(t, g.sources, 1)
|
|
||||||
if g.sources[0].Addr().Is4() {
|
|
||||||
sawV4 = true
|
|
||||||
}
|
|
||||||
if g.sources[0].Addr().Is6() {
|
|
||||||
sawV6 = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
assert.True(t, sawV4 && sawV6)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupPeerRulesSplitsMixedFamilySingleRule verifies that a single
|
|
||||||
// FirewallRule carrying both v4 and v6 source prefixes is split into one
|
|
||||||
// group per family. Each backend keys a rule to a single family, so a
|
|
||||||
// group whose sources span families would mismatch the other family's
|
|
||||||
// sources. mgmt normally emits one rule per family; this guards against
|
|
||||||
// a mixed-family rule slipping through.
|
|
||||||
func TestGroupPeerRulesSplitsMixedFamilySingleRule(t *testing.T) {
|
|
||||||
srcs := [][]byte{
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::1")),
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("2001:db8::2")),
|
|
||||||
}
|
|
||||||
rules := []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
SourcePrefixes: srcs,
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 2, "mixed-family sources in one rule must split into two groups")
|
|
||||||
|
|
||||||
for _, g := range groups {
|
|
||||||
require.Len(t, g.sources, 2)
|
|
||||||
v6 := prefixIsV6(g.sources[0])
|
|
||||||
for _, s := range g.sources {
|
|
||||||
assert.Equal(t, v6, prefixIsV6(s), "every source in a group must share one family")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupPeerRulesMergesSameSelector verifies that rules sharing
|
|
||||||
// every distinguishing field (policy, family, direction, action,
|
|
||||||
// proto, port) collapse into a single multi-source group.
|
|
||||||
func TestGroupPeerRulesMergesSameSelector(t *testing.T) {
|
|
||||||
mk := func(peerIP string) *mgmProto.FirewallRule {
|
|
||||||
return &mgmProto.FirewallRule{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: peerIP,
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rules := []*mgmProto.FirewallRule{mk("10.0.0.1"), mk("10.0.0.2"), mk("10.0.0.3")}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 1)
|
|
||||||
require.Len(t, groups[0].sources, 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupPeerRulesPortSeparates verifies that PortInfo is part of the
|
|
||||||
// selector key: rules differing only in port must not merge, and a
|
|
||||||
// single port must not merge with a range. A regression dropping the
|
|
||||||
// port from the key would collapse rules for different ports into one.
|
|
||||||
func TestGroupPeerRulesPortSeparates(t *testing.T) {
|
|
||||||
mkPort := func(peerIP string, port uint32) *mgmProto.FirewallRule {
|
|
||||||
return &mgmProto.FirewallRule{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: peerIP,
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Port{Port: port}},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules([]*mgmProto.FirewallRule{
|
|
||||||
mkPort("10.0.0.1", 80), mkPort("10.0.0.2", 80), mkPort("10.0.0.3", 443),
|
|
||||||
})
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 2, "rules on different ports must not merge")
|
|
||||||
|
|
||||||
rangeRule := &mgmProto.FirewallRule{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.4",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
PortInfo: &mgmProto.PortInfo{PortSelection: &mgmProto.PortInfo_Range_{Range: &mgmProto.PortInfo_Range{Start: 80, End: 90}}},
|
|
||||||
}
|
|
||||||
groups, denyErr, err = groupPeerRules([]*mgmProto.FirewallRule{mkPort("10.0.0.1", 80), rangeRule})
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 2, "a single port and a range must not merge")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupPeerRulesUsesSourcePrefixesWhenPresent verifies that the
|
|
||||||
// new sourcePrefixes wire field is consumed and produces a
|
|
||||||
// multi-source group in one shot (no client-side merging needed).
|
|
||||||
func TestGroupPeerRulesUsesSourcePrefixesWhenPresent(t *testing.T) {
|
|
||||||
srcs := [][]byte{
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.1")),
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.2")),
|
|
||||||
netiputil.EncodeAddr(netip.MustParseAddr("10.0.0.3")),
|
|
||||||
}
|
|
||||||
rules := []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
SourcePrefixes: srcs,
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 1)
|
|
||||||
require.Len(t, groups[0].sources, 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestGroupPeerRulesActionSeparates verifies the obvious: accept
|
|
||||||
// and drop rules with the same selector don't merge.
|
|
||||||
func TestGroupPeerRulesActionSeparates(t *testing.T) {
|
|
||||||
rules := []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "443",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
require.NoError(t, denyErr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Len(t, groups, 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// failingDeleteFirewall wraps a real firewall.Manager and forces the
|
|
||||||
// next N DeleteFilterRule calls to fail. Used to verify that the acl
|
|
||||||
// manager retains rules whose deletion was rejected by the backend,
|
|
||||||
// so they get retried on the next ApplyFiltering pass instead of
|
|
||||||
// becoming orphans.
|
|
||||||
type failingDeleteFirewall struct {
|
|
||||||
fwmgr.Manager
|
|
||||||
failCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *failingDeleteFirewall) DeleteFilterRule(r fwmgr.Rule) error {
|
|
||||||
if f.failCount > 0 {
|
|
||||||
f.failCount--
|
|
||||||
return errors.New("simulated delete failure")
|
|
||||||
}
|
|
||||||
return f.Manager.DeleteFilterRule(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestApplyFilteringRetainsRulesOnDeleteFailure verifies that a
|
|
||||||
// transient DeleteFilterRule error doesn't make the acl manager
|
|
||||||
// forget about a rule. The rule must remain in peerRulesPairs so the
|
|
||||||
// next ApplyFiltering pass attempts the delete again.
|
|
||||||
func TestApplyFilteringRetainsRulesOnDeleteFailure(t *testing.T) {
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
|
||||||
|
|
||||||
fw := &failingDeleteFirewall{Manager: realFW}
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
// First pass: install a rule.
|
|
||||||
netmap1 := &mgmProto.NetworkMap{
|
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
|
||||||
{
|
|
||||||
PolicyID: []byte("policy-A"),
|
|
||||||
PeerIP: "10.0.0.1",
|
|
||||||
Direction: mgmProto.RuleDirection_IN,
|
|
||||||
Action: mgmProto.RuleAction_DROP,
|
|
||||||
Protocol: mgmProto.RuleProtocol_TCP,
|
|
||||||
Port: "22",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
FirewallRulesIsEmpty: false,
|
|
||||||
}
|
|
||||||
acl.ApplyFiltering(netmap1, false)
|
|
||||||
require.Equal(t, 1, len(acl.peerRulesPairs), "rule should be installed")
|
|
||||||
|
|
||||||
// Second pass: remove the rule from the map. The backend will
|
|
||||||
// fail the delete; the acl manager must retain the rule.
|
|
||||||
fw.failCount = 1
|
|
||||||
netmap2 := &mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: true}
|
|
||||||
acl.ApplyFiltering(netmap2, false)
|
|
||||||
require.Equal(t, 1, len(acl.peerRulesPairs),
|
|
||||||
"rule must be retained when DeleteFilterRule fails so it gets retried")
|
|
||||||
|
|
||||||
// Third pass: same map, backend no longer fails. The rule
|
|
||||||
// should now succeed in being removed.
|
|
||||||
acl.ApplyFiltering(netmap2, false)
|
|
||||||
require.Equal(t, 0, len(acl.peerRulesPairs), "retry should succeed")
|
|
||||||
}
|
|
||||||
@@ -5,18 +5,18 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RuleID aliases manager.RuleID so existing nbid.RuleID references
|
type RuleID string
|
||||||
// keep working while the canonical type lives in the firewall package.
|
|
||||||
type RuleID = manager.RuleID
|
|
||||||
|
|
||||||
// GenerateRuleID returns a deterministic content hash identifying a filter rule.
|
func (r RuleID) ID() string {
|
||||||
func GenerateRuleID(
|
return string(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateRouteRuleKey(
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination manager.Network,
|
destination manager.Network,
|
||||||
proto manager.Protocol,
|
proto manager.Protocol,
|
||||||
@@ -24,7 +24,6 @@ func GenerateRuleID(
|
|||||||
dPort *manager.Port,
|
dPort *manager.Port,
|
||||||
action manager.Action,
|
action manager.Action,
|
||||||
) RuleID {
|
) RuleID {
|
||||||
sources = slices.Clone(sources)
|
|
||||||
manager.SortPrefixes(sources)
|
manager.SortPrefixes(sources)
|
||||||
|
|
||||||
h := sha256.New()
|
h := sha256.New()
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
package acl
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
|
||||||
fwmgr "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
// sourcesRecordingFirewall wraps a real firewall.Manager and records
|
|
||||||
// the source prefixes of every AddFilterRule call.
|
|
||||||
type sourcesRecordingFirewall struct {
|
|
||||||
fwmgr.Manager
|
|
||||||
mu sync.Mutex
|
|
||||||
sources [][]netip.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *sourcesRecordingFirewall) AddFilterRule(id []byte, sources []netip.Prefix, destination fwmgr.Network, proto fwmgr.Protocol, sPort, dPort *fwmgr.Port, action fwmgr.Action) (fwmgr.Rule, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
f.sources = append(f.sources, sources)
|
|
||||||
f.mu.Unlock()
|
|
||||||
return f.Manager.AddFilterRule(id, sources, destination, proto, sPort, dPort, action)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestLegacyManagementFallbackUsesMatchAnySources verifies the
|
|
||||||
// allow-all fallback for old management servers (empty FirewallRules
|
|
||||||
// without the FirewallRulesIsEmpty flag) reaches the firewall as /0
|
|
||||||
// match-any sources. The fallback rule carries PeerIP 0.0.0.0; if that
|
|
||||||
// were converted to a host prefix (0.0.0.0/32) it would match nothing
|
|
||||||
// and all peer traffic would be dropped.
|
|
||||||
func TestLegacyManagementFallbackUsesMatchAnySources(t *testing.T) {
|
|
||||||
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
|
||||||
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
|
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
defer ctrl.Finish()
|
|
||||||
|
|
||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
|
||||||
network := netip.MustParsePrefix("172.0.0.1/32")
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{IP: network.Addr(), Network: network}).AnyTimes()
|
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
|
||||||
|
|
||||||
realFW, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { require.NoError(t, realFW.Close(nil)) }()
|
|
||||||
|
|
||||||
fw := &sourcesRecordingFirewall{Manager: realFW}
|
|
||||||
acl := NewDefaultManager(fw)
|
|
||||||
|
|
||||||
// Old management: no rules and no FirewallRulesIsEmpty flag.
|
|
||||||
acl.ApplyFiltering(&mgmProto.NetworkMap{FirewallRules: nil, FirewallRulesIsEmpty: false}, false)
|
|
||||||
|
|
||||||
fw.mu.Lock()
|
|
||||||
defer fw.mu.Unlock()
|
|
||||||
require.NotEmpty(t, fw.sources, "legacy fallback must install at least one allow-all rule")
|
|
||||||
for _, sources := range fw.sources {
|
|
||||||
require.NotEmpty(t, sources)
|
|
||||||
for _, p := range sources {
|
|
||||||
assert.Equal(t, 0, p.Bits(), "legacy fallback source %s must be a /0 match-any prefix", p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/md5"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -21,10 +23,6 @@ import (
|
|||||||
|
|
||||||
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
||||||
|
|
||||||
// ErrNoRuleReturned is returned when the firewall backend reports success
|
|
||||||
// from AddFilterRule but yields no rule to track.
|
|
||||||
var ErrNoRuleReturned = errors.New("backend returned no rule")
|
|
||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
@@ -33,46 +31,17 @@ type Manager interface {
|
|||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
|
ipsetCounter int
|
||||||
peerRulesPairs map[id.RuleID][]firewall.Rule
|
peerRulesPairs map[id.RuleID][]firewall.Rule
|
||||||
routeRules map[id.RuleID]firewall.Rule
|
routeRules map[id.RuleID]struct{}
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// peerRuleGroup collapses a set of single-source FirewallRules sharing
|
|
||||||
// the same selector into one multi-source rule to push to the backend.
|
|
||||||
type peerRuleGroup struct {
|
|
||||||
direction mgmProto.RuleDirection
|
|
||||||
action mgmProto.RuleAction
|
|
||||||
protocol mgmProto.RuleProtocol
|
|
||||||
port *mgmProto.PortInfo
|
|
||||||
// legacyPort is used only when PortInfo is empty (old management).
|
|
||||||
legacyPort string
|
|
||||||
policyID []byte
|
|
||||||
sources []netip.Prefix
|
|
||||||
}
|
|
||||||
|
|
||||||
// peerRuleKey is the comparable selector that decides which single-source
|
|
||||||
// rules merge into one group. Rules with an equal key collapse into one
|
|
||||||
// multi-source backend rule. PortInfo is flattened into its scalar fields
|
|
||||||
// so the key compares by value; policyID keeps policies separate so two
|
|
||||||
// policies authorizing different peers don't merge under one attribution.
|
|
||||||
type peerRuleKey struct {
|
|
||||||
v6 bool
|
|
||||||
policyID string
|
|
||||||
direction mgmProto.RuleDirection
|
|
||||||
action mgmProto.RuleAction
|
|
||||||
protocol mgmProto.RuleProtocol
|
|
||||||
legacyPort string
|
|
||||||
port uint16
|
|
||||||
rangeStart uint16
|
|
||||||
rangeEnd uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
||||||
return &DefaultManager{
|
return &DefaultManager{
|
||||||
firewall: fm,
|
firewall: fm,
|
||||||
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
|
||||||
routeRules: make(map[id.RuleID]firewall.Rule),
|
routeRules: make(map[id.RuleID]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,12 +68,10 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := d.applyPeerACLs(networkMap); err != nil {
|
d.applyPeerACLs(networkMap)
|
||||||
log.Errorf("apply peer ACLs: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.firewall.Flush(); err != nil {
|
if err := d.firewall.Flush(); err != nil {
|
||||||
@@ -112,7 +79,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules := networkMap.FirewallRules
|
rules := networkMap.FirewallRules
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||||
@@ -135,167 +102,59 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) error {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Group incoming single-source rules from management by their
|
|
||||||
// (direction, action, proto, port) selector and merge sources.
|
|
||||||
// One call to the firewall backend per merged rule.
|
|
||||||
// A deny we cannot decode would leave its traffic unblocked, so skip
|
|
||||||
// the whole pass and keep existing rules until the next sync.
|
|
||||||
groups, denyErr, err := groupPeerRules(rules)
|
|
||||||
if denyErr != nil {
|
|
||||||
return fmt.Errorf("decode deny rule sources: %w", denyErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
newRulePairs := make(map[id.RuleID][]firewall.Rule)
|
||||||
|
ipsetByRuleSelectors := make(map[string]string)
|
||||||
|
|
||||||
|
// TODO: deny rules should be fatal: if a deny rule fails to apply, we must
|
||||||
|
// roll back all allow rules to avoid a fail-open where allowed traffic bypasses
|
||||||
|
// the missing deny. Currently we accumulate errors and continue.
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
if err != nil {
|
for _, r := range rules {
|
||||||
merr = multierror.Append(merr, err)
|
// if this rule is member of rule selection with more than DefaultIPsCountForSet
|
||||||
}
|
// it's IP address can be used in the ipset for firewall manager which supports it
|
||||||
|
selector := d.getRuleGroupingSelector(r)
|
||||||
// Apply denies first. A deny that fails to install is a security
|
ipsetName, ok := ipsetByRuleSelectors[selector]
|
||||||
// failure (fail-open), so if any deny errors we roll back the
|
if !ok {
|
||||||
// denies we already installed in this pass and bail out without
|
d.ipsetCounter++
|
||||||
// installing any accept. Pre-existing rules stay untouched until
|
ipsetName = fmt.Sprintf("nb%07d", d.ipsetCounter)
|
||||||
// the next successful pass clears them.
|
ipsetByRuleSelectors[selector] = ipsetName
|
||||||
denies, accepts := splitDenyAccept(groups)
|
|
||||||
if err := d.installPeerGroups(denies, newRulePairs, true); err != nil {
|
|
||||||
return fmt.Errorf("install deny rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.installPeerGroups(accepts, newRulePairs, false); err != nil {
|
|
||||||
merr = multierror.Append(merr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tear down rules that disappeared from the networkmap. Any rule
|
|
||||||
// the backend refuses to delete stays in our tracking so it gets
|
|
||||||
// retried on the next ApplyFiltering. Otherwise a transient
|
|
||||||
// delete failure would leak the rule in the firewall until the
|
|
||||||
// process exits.
|
|
||||||
for pairID, rules := range d.peerRulesPairs {
|
|
||||||
if _, ok := newRulePairs[pairID]; ok {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
var remaining []firewall.Rule
|
pairID, rulePair, err := d.protoRuleToFirewallRule(r, ipsetName)
|
||||||
for _, rule := range rules {
|
|
||||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
|
||||||
log.Errorf("failed to delete peer firewall rule, will retry: %v", err)
|
|
||||||
remaining = append(remaining, rule)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(remaining) > 0 {
|
|
||||||
newRulePairs[pairID] = remaining
|
|
||||||
}
|
|
||||||
}
|
|
||||||
d.peerRulesPairs = newRulePairs
|
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// installPeerGroups applies each group and records the resulting rule
|
|
||||||
// pairs in newRulePairs. With atomic set (deny rules), a single failure
|
|
||||||
// rolls back every rule installed in this call and returns, leaving the
|
|
||||||
// firewall exactly as before: denies are fail-closed and must be applied
|
|
||||||
// all-or-nothing. With atomic unset (accept rules), failures are
|
|
||||||
// accumulated and the remaining groups still install, so one malformed
|
|
||||||
// allow cannot drop every other legitimate allow in the pass.
|
|
||||||
func (d *DefaultManager) installPeerGroups(groups []*peerRuleGroup, newRulePairs map[id.RuleID][]firewall.Rule, atomic bool) error {
|
|
||||||
var freshlyInstalled []id.RuleID
|
|
||||||
var merr *multierror.Error
|
|
||||||
for _, g := range groups {
|
|
||||||
pairID, rulePair, err := d.applyPeerGroup(g)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if atomic {
|
|
||||||
d.rollbackInstalled(freshlyInstalled)
|
|
||||||
return fmt.Errorf("apply firewall rule: %w", err)
|
|
||||||
}
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("apply firewall rule: %w", err))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if len(rulePair) == 0 {
|
if len(rulePair) > 0 {
|
||||||
continue
|
d.peerRulesPairs[pairID] = rulePair
|
||||||
|
newRulePairs[pairID] = rulePair
|
||||||
}
|
}
|
||||||
if _, existed := d.peerRulesPairs[pairID]; !existed {
|
|
||||||
freshlyInstalled = append(freshlyInstalled, pairID)
|
|
||||||
}
|
|
||||||
d.peerRulesPairs[pairID] = rulePair
|
|
||||||
newRulePairs[pairID] = rulePair
|
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) rollbackInstalled(pairIDs []id.RuleID) {
|
if merr != nil {
|
||||||
var merr *multierror.Error
|
log.Errorf("failed to apply %d peer ACL rule(s): %v", merr.Len(), nberrors.FormatErrorOrNil(merr))
|
||||||
for _, pairID := range pairIDs {
|
}
|
||||||
// Keep any rule the backend refuses to delete tracked so it is
|
|
||||||
// retried on the next ApplyFiltering instead of leaking in the
|
for pairID, rules := range d.peerRulesPairs {
|
||||||
// firewall with no tracking left to remove it.
|
if _, ok := newRulePairs[pairID]; !ok {
|
||||||
var remaining []firewall.Rule
|
for _, rule := range rules {
|
||||||
for _, rule := range d.peerRulesPairs[pairID] {
|
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
log.Errorf("failed to delete peer firewall rule: %v", err)
|
||||||
merr = multierror.Append(merr, fmt.Errorf("rule %s: %w", pairID, err))
|
continue
|
||||||
remaining = append(remaining, rule)
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
if len(remaining) > 0 {
|
|
||||||
d.peerRulesPairs[pairID] = remaining
|
|
||||||
} else {
|
|
||||||
delete(d.peerRulesPairs, pairID)
|
delete(d.peerRulesPairs, pairID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := nberrors.FormatErrorOrNil(merr); err != nil {
|
d.peerRulesPairs = newRulePairs
|
||||||
log.Errorf("rollback peer rules: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *DefaultManager) applyPeerGroup(g *peerRuleGroup) (id.RuleID, []firewall.Rule, error) {
|
|
||||||
protocol, err := ConvertToFirewallProtocol(g.protocol)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
|
||||||
}
|
|
||||||
action, err := convertFirewallAction(g.action)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("skipping firewall rule: %w", err)
|
|
||||||
}
|
|
||||||
port, err := resolveGroupPort(g)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var fwRule firewall.Rule
|
|
||||||
switch g.direction {
|
|
||||||
case mgmProto.RuleDirection_IN:
|
|
||||||
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, nil, port, action)
|
|
||||||
case mgmProto.RuleDirection_OUT:
|
|
||||||
if d.firewall.IsStateful() {
|
|
||||||
return "", nil, nil
|
|
||||||
}
|
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
|
||||||
return "", nil, nil
|
|
||||||
}
|
|
||||||
fwRule, err = d.firewall.AddFilterRule(g.policyID, g.sources, firewall.Network{}, protocol, port, nil, action)
|
|
||||||
default:
|
|
||||||
return "", nil, errors.New("invalid direction")
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("add firewall rule: %w", err)
|
|
||||||
}
|
|
||||||
if fwRule == nil {
|
|
||||||
return "", nil, fmt.Errorf("add firewall rule: %w", ErrNoRuleReturned)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive the pair id from the backend rule, like the route path:
|
|
||||||
// the backend dedups identical content, so two policies authorizing
|
|
||||||
// the same flow resolve to the same id and a single backing rule.
|
|
||||||
return fwRule.ID(), []firewall.Rule{fwRule}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||||
newRouteRules := make(map[id.RuleID]firewall.Rule, len(rules))
|
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
// Apply new rules - firewall manager will return the existing rule if already present
|
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
addedRule, err := d.applyRouteACL(rule, dynamicResolver)
|
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||||
@@ -304,18 +163,16 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
|||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newRouteRules[addedRule.ID()] = addedRule
|
newRouteRules[id] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tear down old route rules; retain ones the backend refused so a
|
// Clean up old firewall rules
|
||||||
// transient failure doesn't leave orphaned rules in the firewall.
|
for id := range d.routeRules {
|
||||||
for ruleID, rule := range d.routeRules {
|
if _, exists := newRouteRules[id]; !exists {
|
||||||
if _, exists := newRouteRules[ruleID]; exists {
|
if err := d.firewall.DeleteRouteRule(id); err != nil {
|
||||||
continue
|
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
|
||||||
}
|
}
|
||||||
if err := d.firewall.DeleteFilterRule(rule); err != nil {
|
// implicitly deleted from the map
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete route rule, will retry: %w", err))
|
|
||||||
newRouteRules[ruleID] = rule
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -323,202 +180,102 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dyn
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (firewall.Rule, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return nil, ErrSourceRangesEmpty
|
return "", ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
|
|
||||||
var sources []netip.Prefix
|
var sources []netip.Prefix
|
||||||
for _, sourceRange := range rule.SourceRanges {
|
for _, sourceRange := range rule.SourceRanges {
|
||||||
source, err := netip.ParsePrefix(sourceRange)
|
source, err := netip.ParsePrefix(sourceRange)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse source range: %w", err)
|
return "", fmt.Errorf("parse source range: %w", err)
|
||||||
}
|
}
|
||||||
sources = append(sources, firewall.UnmapPrefix(source))
|
sources = append(sources, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
destination, err := determineDestination(rule, dynamicResolver, sources)
|
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("determine destination: %w", err)
|
return "", fmt.Errorf("determine destination: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := ConvertToFirewallProtocol(rule.Protocol)
|
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid protocol: %w", err)
|
return "", fmt.Errorf("invalid protocol: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
action, err := convertFirewallAction(rule.Action)
|
action, err := convertFirewallAction(rule.Action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid action: %w", err)
|
return "", fmt.Errorf("invalid action: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dPorts := convertPortInfo(rule.PortInfo)
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
addedRule, err := d.firewall.AddFilterRule(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add route rule: %w", err)
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
}
|
|
||||||
if addedRule == nil {
|
|
||||||
return nil, fmt.Errorf("add route rule: %w", ErrNoRuleReturned)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return addedRule, nil
|
return id.RuleID(addedRule.ID()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// splitDenyAccept partitions groups by action so denies can be
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
// applied before accepts. Order within each bucket is preserved.
|
r *mgmProto.FirewallRule,
|
||||||
func splitDenyAccept(groups []*peerRuleGroup) (denies, accepts []*peerRuleGroup) {
|
ipsetName string,
|
||||||
for _, g := range groups {
|
) (id.RuleID, []firewall.Rule, error) {
|
||||||
if g.action == mgmProto.RuleAction_DROP {
|
ip, err := extractRuleIP(r)
|
||||||
denies = append(denies, g)
|
|
||||||
} else {
|
|
||||||
accepts = append(accepts, g)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return denies, accepts
|
|
||||||
}
|
|
||||||
|
|
||||||
// groupPeerRules merges single-source rules sharing a selector into
|
|
||||||
// multi-source groups. It splits source-decode failures by action:
|
|
||||||
// denyErr is non-nil when a deny rule could not be decoded, which is a
|
|
||||||
// fail-open risk the caller must treat as fatal for the pass; err
|
|
||||||
// carries the tolerable accept-rule failures the caller can log and
|
|
||||||
// continue past.
|
|
||||||
func groupPeerRules(rules []*mgmProto.FirewallRule) (groups []*peerRuleGroup, denyErr error, err error) {
|
|
||||||
var denyMerr, acceptMerr *multierror.Error
|
|
||||||
byKey := make(map[peerRuleKey]*peerRuleGroup)
|
|
||||||
order := make([]peerRuleKey, 0)
|
|
||||||
|
|
||||||
for _, r := range rules {
|
|
||||||
srcs, decErr := extractRuleSources(r)
|
|
||||||
if decErr != nil {
|
|
||||||
if r.Action == mgmProto.RuleAction_DROP {
|
|
||||||
denyMerr = multierror.Append(denyMerr, decErr)
|
|
||||||
} else {
|
|
||||||
acceptMerr = multierror.Append(acceptMerr, decErr)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// A single FirewallRule normally carries one address family, but
|
|
||||||
// split by family defensively: each backend keys a rule to one
|
|
||||||
// family and would mismatch sources of the other, so a group's
|
|
||||||
// sources must never span families.
|
|
||||||
v4, v6 := splitPrefixesByFamily(srcs)
|
|
||||||
for _, sub := range []struct {
|
|
||||||
isV6 bool
|
|
||||||
sources []netip.Prefix
|
|
||||||
}{{false, v4}, {true, v6}} {
|
|
||||||
if len(sub.sources) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
key := ruleGroupKey(r, sub.isV6)
|
|
||||||
g, ok := byKey[key]
|
|
||||||
if !ok {
|
|
||||||
g = &peerRuleGroup{
|
|
||||||
direction: r.Direction,
|
|
||||||
action: r.Action,
|
|
||||||
protocol: r.Protocol,
|
|
||||||
port: r.PortInfo,
|
|
||||||
legacyPort: r.Port,
|
|
||||||
policyID: r.PolicyID,
|
|
||||||
}
|
|
||||||
byKey[key] = g
|
|
||||||
order = append(order, key)
|
|
||||||
}
|
|
||||||
g.sources = append(g.sources, sub.sources...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
out := make([]*peerRuleGroup, 0, len(order))
|
|
||||||
for _, k := range order {
|
|
||||||
out = append(out, byKey[k])
|
|
||||||
}
|
|
||||||
return out, nberrors.FormatErrorOrNil(denyMerr), nberrors.FormatErrorOrNil(acceptMerr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func prefixIsV6(p netip.Prefix) bool {
|
|
||||||
return p.Addr().Is6() && !p.Addr().Is4In6()
|
|
||||||
}
|
|
||||||
|
|
||||||
// splitPrefixesByFamily partitions prefixes into IPv4 and IPv6 groups.
|
|
||||||
func splitPrefixesByFamily(prefixes []netip.Prefix) (v4, v6 []netip.Prefix) {
|
|
||||||
for _, p := range prefixes {
|
|
||||||
if prefixIsV6(p) {
|
|
||||||
v6 = append(v6, p)
|
|
||||||
} else {
|
|
||||||
v4 = append(v4, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return v4, v6
|
|
||||||
}
|
|
||||||
|
|
||||||
// ruleGroupKey builds the selector key for a rule. v6 must reflect the
|
|
||||||
// rule's source family: mgmt emits one rule per family and mixing them
|
|
||||||
// would break ICMP-variant selection in uspfilter.
|
|
||||||
func ruleGroupKey(r *mgmProto.FirewallRule, v6 bool) peerRuleKey {
|
|
||||||
k := peerRuleKey{
|
|
||||||
v6: v6,
|
|
||||||
policyID: string(r.PolicyID),
|
|
||||||
direction: r.Direction,
|
|
||||||
action: r.Action,
|
|
||||||
protocol: r.Protocol,
|
|
||||||
legacyPort: r.Port,
|
|
||||||
}
|
|
||||||
if pi := r.PortInfo; pi != nil {
|
|
||||||
k.port = uint16(pi.GetPort())
|
|
||||||
if rng := pi.GetRange(); rng != nil {
|
|
||||||
k.rangeStart = uint16(rng.GetStart())
|
|
||||||
k.rangeEnd = uint16(rng.GetEnd())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return k
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractRuleSources returns all source prefixes the rule applies to.
|
|
||||||
// New management populates sourcePrefixes; older management sets PeerIP.
|
|
||||||
func extractRuleSources(r *mgmProto.FirewallRule) ([]netip.Prefix, error) {
|
|
||||||
if len(r.SourcePrefixes) > 0 {
|
|
||||||
out := make([]netip.Prefix, 0, len(r.SourcePrefixes))
|
|
||||||
for _, raw := range r.SourcePrefixes {
|
|
||||||
addr, err := netiputil.DecodeAddr(raw)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("decode source prefix: %w", err)
|
|
||||||
}
|
|
||||||
out = append(out, netip.PrefixFrom(addr.Unmap(), addr.Unmap().BitLen()))
|
|
||||||
}
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
peerIP := r.PeerIP //nolint:staticcheck // PeerIP is the legacy source field for old management servers
|
|
||||||
addr, err := netip.ParseAddr(peerIP)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("parse peer IP %q: %w", peerIP, err)
|
return "", nil, err
|
||||||
}
|
}
|
||||||
addr = addr.Unmap()
|
|
||||||
// An unspecified PeerIP means "any peer" (legacy management
|
|
||||||
// allow-all fallback); only a /0 prefix matches any source in the
|
|
||||||
// backends, a full-length prefix would match nothing.
|
|
||||||
if addr.IsUnspecified() {
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(addr, 0)}, nil
|
|
||||||
}
|
|
||||||
return []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveGroupPort(g *peerRuleGroup) (*firewall.Port, error) {
|
protocol, err := convertToFirewallProtocol(r.Protocol)
|
||||||
if !portInfoEmpty(g.port) {
|
if err != nil {
|
||||||
return convertPortInfo(g.port), nil
|
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||||
}
|
}
|
||||||
if g.legacyPort != "" {
|
|
||||||
value, err := strconv.ParseUint(g.legacyPort, 10, 16)
|
action, err := convertFirewallAction(r.Action)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("skipping firewall rule: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var port *firewall.Port
|
||||||
|
if !portInfoEmpty(r.PortInfo) {
|
||||||
|
port = convertPortInfo(r.PortInfo)
|
||||||
|
} else if r.Port != "" {
|
||||||
|
// old version of management, single port
|
||||||
|
value, err := strconv.Atoi(r.Port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid port: %w", err)
|
return "", nil, fmt.Errorf("invalid port: %w", err)
|
||||||
}
|
}
|
||||||
return &firewall.Port{
|
port = &firewall.Port{
|
||||||
Values: []uint16{uint16(value)},
|
Values: []uint16{uint16(value)},
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
// nolint:nilnil // a nil port legitimately means "no port restriction"
|
|
||||||
return nil, nil
|
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||||
|
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||||
|
return ruleID, rulesPair, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []firewall.Rule
|
||||||
|
switch r.Direction {
|
||||||
|
case mgmProto.RuleDirection_IN:
|
||||||
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
|
case mgmProto.RuleDirection_OUT:
|
||||||
|
if d.firewall.IsStateful() {
|
||||||
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
// return traffic for outbound connections if firewall is stateless
|
||||||
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
|
default:
|
||||||
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ruleID, rules, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
||||||
@@ -537,9 +294,85 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertToFirewallProtocol maps a management rule protocol to the
|
func (d *DefaultManager) addInRules(
|
||||||
// firewall protocol type.
|
id []byte,
|
||||||
func ConvertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
ip netip.Addr,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
|
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, nil, port, action, ipsetName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DefaultManager) addOutRules(
|
||||||
|
id []byte,
|
||||||
|
ip netip.Addr,
|
||||||
|
protocol firewall.Protocol,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipsetName string,
|
||||||
|
) ([]firewall.Rule, error) {
|
||||||
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := d.firewall.AddPeerFiltering(id, ip.AsSlice(), protocol, port, nil, action, ipsetName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPeerRuleID returns unique ID for the rule based on its parameters.
|
||||||
|
func (d *DefaultManager) getPeerRuleID(
|
||||||
|
ip netip.Addr,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
direction int,
|
||||||
|
port *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) id.RuleID {
|
||||||
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||||
|
if port != nil {
|
||||||
|
idStr += port.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRuleGroupingSelector takes all rule properties except IP address to build selector
|
||||||
|
func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) string {
|
||||||
|
return fmt.Sprintf("%v:%v:%v:%s:%v", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port, rule.PortInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// extractRuleIP extracts the peer IP from a firewall rule.
|
||||||
|
// If sourcePrefixes is populated (new management), decode the first entry and use its address.
|
||||||
|
// Otherwise fall back to the deprecated PeerIP string field (old management).
|
||||||
|
func extractRuleIP(r *mgmProto.FirewallRule) (netip.Addr, error) {
|
||||||
|
if len(r.SourcePrefixes) > 0 {
|
||||||
|
addr, err := netiputil.DecodeAddr(r.SourcePrefixes[0])
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, fmt.Errorf("decode source prefix: %w", err)
|
||||||
|
}
|
||||||
|
return addr.Unmap(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:staticcheck // PeerIP used for backward compatibility with old management
|
||||||
|
addr, err := netip.ParseAddr(r.PeerIP)
|
||||||
|
if err != nil {
|
||||||
|
return netip.Addr{}, fmt.Errorf("invalid IP address, skipping firewall rule")
|
||||||
|
}
|
||||||
|
return addr.Unmap(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case mgmProto.RuleProtocol_TCP:
|
case mgmProto.RuleProtocol_TCP:
|
||||||
return firewall.ProtocolTCP, nil
|
return firewall.ProtocolTCP, nil
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
fwmanager "github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
@@ -77,9 +76,9 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("add extra rules", func(t *testing.T) {
|
t.Run("add extra rules", func(t *testing.T) {
|
||||||
existedPairs := map[fwmanager.RuleID]struct{}{}
|
existedPairs := map[string]struct{}{}
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
existedPairs[id] = struct{}{}
|
existedPairs[id.ID()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove first rule
|
// remove first rule
|
||||||
@@ -106,7 +105,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
if _, ok := existedPairs[id]; ok {
|
if _, ok := existedPairs[id.ID()]; ok {
|
||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -360,13 +360,7 @@ func isRedirectURLPortUsed(redirectURL string, excludedRanges []excludedPortRang
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// FreeBSD 15 disables connecting to INADDR_ANY (0.0.0.0) as a localhost
|
addr := fmt.Sprintf(":%s", port)
|
||||||
// alias by default, ensure explicit ip for localhost.
|
|
||||||
host := parsedURL.Hostname()
|
|
||||||
if host == "" {
|
|
||||||
host = "127.0.0.1"
|
|
||||||
}
|
|
||||||
addr := net.JoinHostPort(host, port)
|
|
||||||
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
conn, err := net.DialTimeout("tcp", addr, 3*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user