Compare commits

..

3 Commits

Author SHA1 Message Date
bcmmbaga
feb8e90ae1 Evaluate all applied posture checks on source peers only
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 23:28:34 +03:00
bcmmbaga
076d6d8a87 Evaluate all applied posture checks once
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 22:12:47 +03:00
bcmmbaga
c8c25221bd Apply policy posture checks on peer
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 21:49:28 +03:00
575 changed files with 11918 additions and 46209 deletions

View File

@@ -1,4 +1,4 @@
FROM golang:1.23-bullseye FROM golang:1.21-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\ && apt-get -y install --no-install-recommends\

View File

@@ -7,7 +7,7 @@
"features": { "features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {}, "ghcr.io/devcontainers/features/docker-in-docker:2": {},
"ghcr.io/devcontainers/features/go:1": { "ghcr.io/devcontainers/features/go:1": {
"version": "1.23" "version": "1.21"
} }
}, },
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",

View File

@@ -31,22 +31,14 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**Is any other VPN software installed?** **NetBird status -dA output:**
If yes, which one? If applicable, add the `netbird status -dA' command output.
**Debug output** **Do you face any (non-mobile) client issues?**
To help us resolve the problem, please attach the following debug output Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
@@ -55,10 +47,3 @@ If applicable, add screenshots to help explain your problem.
**Additional context** **Additional context**
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -1,4 +1,4 @@
name: "Darwin" name: Test Code Darwin
on: on:
push: push:
@@ -12,7 +12,9 @@ concurrency:
jobs: jobs:
test: test:
name: "Client / Unit" strategy:
matrix:
store: ['sqlite']
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
@@ -42,5 +44,4 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -1,4 +1,5 @@
name: "FreeBSD"
name: Test Code FreeBSD
on: on:
push: push:
@@ -12,7 +13,6 @@ concurrency:
jobs: jobs:
test: test:
name: "Client / Unit"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@@ -24,7 +24,7 @@ jobs:
copyback: false copyback: false
release: "14.1" release: "14.1"
prepare: | prepare: |
pkg install -y go pkgconf xorg pkg install -y go
# -x - to print all executed commands # -x - to print all executed commands
# -e - to faile on first error # -e - to faile on first error
@@ -33,7 +33,7 @@ jobs:
time go build -o netbird client/main.go time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd # check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/... time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use` # NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/... time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./encryption/...

View File

@@ -1,4 +1,4 @@
name: Linux name: Test Code Linux
on: on:
push: push:
@@ -12,21 +12,11 @@ concurrency:
jobs: jobs:
build-cache: build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
outputs: steps:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -48,6 +38,7 @@ jobs:
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies - name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true' if: steps.cache.outputs.cache-hit != 'true'
@@ -98,7 +89,6 @@ jobs:
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test: test:
name: "Client / Unit"
needs: [build-cache] needs: [build-cache]
strategy: strategy:
fail-fast: false fail-fast: false
@@ -144,324 +134,14 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management: test_management:
name: "Management / Unit"
needs: [ build-cache ] needs: [ build-cache ]
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ 'amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -503,15 +183,58 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ - name: Cache Go modules
-timeout 20m ./management/... uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker: test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ] needs: [ build-cache ]
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:

View File

@@ -1,4 +1,4 @@
name: "Windows" name: Test Code Windows
on: on:
push: push:
@@ -14,7 +14,6 @@ concurrency:
jobs: jobs:
test: test:
name: "Client / Unit"
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
@@ -66,7 +65,7 @@ jobs:
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test - name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output - name: test output
if: ${{ always() }} if: ${{ always() }}
run: Get-Content test-out.txt run: Get-Content test-out.txt

View File

@@ -1,4 +1,4 @@
name: Lint name: golangci-lint
on: [pull_request] on: [pull_request]
permissions: permissions:
@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin ignore_words_list: erro,clienta,hastable,iif,groupd
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:
@@ -27,14 +27,7 @@ jobs:
fail-fast: false fail-fast: false
matrix: matrix:
os: [macos-latest, windows-latest, ubuntu-latest] os: [macos-latest, windows-latest, ubuntu-latest]
include: name: lint
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
timeout-minutes: 15 timeout-minutes: 15
steps: steps:

View File

@@ -1,4 +1,4 @@
name: Mobile name: Mobile build validation
on: on:
push: push:
@@ -12,7 +12,6 @@ concurrency:
jobs: jobs:
android_build: android_build:
name: "Android / Build"
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
@@ -48,7 +47,6 @@ jobs:
CGO_ENABLED: 0 CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620 ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build: ios_build:
name: "iOS / Build"
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository

View File

@@ -9,10 +9,10 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.18" SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -71,7 +71,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -150,7 +150,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4

View File

@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres' ]
services: services:
postgres: postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }} image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
@@ -34,19 +34,6 @@ jobs:
--health-timeout 5s --health-timeout 5s
ports: ports:
- 5432:5432 - 5432:5432
mysql:
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
env:
MYSQL_USER: netbird
MYSQL_PASSWORD: mysql
MYSQL_ROOT_PASSWORD: mysqlroot
MYSQL_DATABASE: netbird
options: >-
--health-cmd "mysqladmin ping --silent"
--health-interval 10s
--health-timeout 5s
ports:
- 3306:3306
steps: steps:
- name: Set Database Connection String - name: Set Database Connection String
run: | run: |
@@ -55,11 +42,6 @@ jobs:
else else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi fi
if [ "${{ matrix.store }}" == "mysql" ]; then
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
fi
- name: Install jq - name: Install jq
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
@@ -102,7 +84,6 @@ jobs:
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }} NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values - name: check values
@@ -131,7 +112,6 @@ jobs:
CI_NETBIRD_SIGNAL_PORT: 12345 CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$' NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -169,7 +149,6 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values # check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml

1
.gitignore vendored
View File

@@ -29,4 +29,3 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store .DS_Store
vendor/

View File

@@ -103,7 +103,7 @@ linters:
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers - predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint. - revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers. - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
- wastedassign # wastedassign finds wasted assignment statements - wastedassign # wastedassign finds wasted assignment statements
issues: issues:
# Maximum count of issues with the same text. # Maximum count of issues with the same text.

View File

@@ -179,51 +179,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/relay:{{ .Version }}-amd64 - netbirdio/relay:{{ .Version }}-amd64
ids: ids:
@@ -422,18 +377,6 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }} - name_template: netbirdio/relay:{{ .Version }}
image_templates: image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8 - netbirdio/relay:{{ .Version }}-arm64v8

View File

@@ -50,12 +50,10 @@ nfpms:
- netbird-ui - netbird-ui
formats: formats:
- deb - deb
scripts:
postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/build/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png - src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@@ -69,12 +67,10 @@ nfpms:
- netbird-ui - netbird-ui
formats: formats:
- rpm - rpm
scripts:
postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/build/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/assets/netbird.png - src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@@ -1,3 +1,3 @@
Mikhail Bragin (https://github.com/braginini) Mikhail Bragin (https://github.com/braginini)
Maycon Santos (https://github.com/mlsmaycon) Maycon Santos (https://github.com/mlsmaycon)
NetBird GmbH Wiretrustee UG (haftungsbeschränkt)

View File

@@ -3,10 +3,10 @@
We are incredibly thankful for the contributions we receive from the community. We are incredibly thankful for the contributions we receive from the community.
We require our external contributors to sign a Contributor License Agreement ("CLA") in We require our external contributors to sign a Contributor License Agreement ("CLA") in
order to ensure that our projects remain licensed under Free and Open Source licenses such order to ensure that our projects remain licensed under Free and Open Source licenses such
as BSD-3 while allowing NetBird to build a sustainable business. as BSD-3 while allowing Wiretrustee to build a sustainable business.
NetBird is committed to having a true Open Source Software ("OSS") license for Wiretrustee is committed to having a true Open Source Software ("OSS") license for
our software. A CLA enables NetBird to safely commercialize our products our software. A CLA enables Wiretrustee to safely commercialize our products
while keeping a standard OSS license with all the rights that license grants to users: the while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project. source, or to completely fork the project.
@@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
This highlights only some of key terms of the CLA. It has no legal value and you should This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing. carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work <li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
in commercial products. in commercial products.
</li> </li>
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a <li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
license to use that patent including within commercial products. You also agree that you license to use that patent including within commercial products. You also agree that you
have permission to grant this license. have permission to grant this license.
</li> </li>
@@ -45,7 +45,7 @@ more.
# Why require a CLA? # Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
products. products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
@@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
information from your account) and to fill in a few additional details such as your name and email address. We will only information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes. use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
require you to sign again. require you to sign again.
# Legal Terms and Agreement # Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose. your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird, Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
You reserve all right, title, and interest in and to Your Contributions. You reserve all right, title, and interest in and to Your Contributions.
1. Definitions. 1. Definitions.
``` ```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner "You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
@@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
``` ```
``` ```
"Contribution" shall mean any original work of authorship, including any modifications or additions to "Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in, an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
or documentation of, any of the products owned or managed by NetBird (the "Work"). or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of, source code control systems, and issue tracking systems that are managed by, or on behalf of,
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution." marked or otherwise designated in writing by You as "Not a Contribution."
``` ```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird 2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works. perform, sublicense, and distribute Your Contributions and such derivative works.
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and 3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
@@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
intellectual property that you create that includes your Contributions, you represent that you have received intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
with NetBird. with Wiretrustee.
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of 5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
@@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from 7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
any Contribution, identifying the complete details of its source and of any license or other restriction (including, any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]". conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these 8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect. representations inaccurate in any respect.

View File

@@ -1,6 +1,6 @@
BSD 3-Clause License BSD 3-Clause License
Copyright (c) 2022 NetBird GmbH & AUTHORS Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
@@ -10,4 +10,4 @@ Redistribution and use in source and binary forms, with or without modification,
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@@ -1,6 +1,11 @@
<div align="center"> <p align="center">
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
Learn more
</a>
</p>
<br/> <br/>
<br/> <div align="center">
<p align="center"> <p align="center">
<img width="234" src="docs/media/logo-full.png"/> <img width="234" src="docs/media/logo-full.png"/>
</p> </p>
@@ -33,10 +38,6 @@
<br/> <br/>
</strong> </strong>
<br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
</a>
</p> </p>
<br> <br>

View File

@@ -1,4 +1,4 @@
FROM alpine:3.21.3 FROM alpine:3.21.0
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]

View File

@@ -1,17 +0,0 @@
FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird
WORKDIR /var/lib/netbird
USER netbird:netbird
ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]

View File

@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
// check if we need to generate JWT token // check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config) needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
return return
}) })
if err != nil { if err != nil {

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status"
) )
const errCloseConnection = "Failed to close connection: %v" const errCloseConnection = "Failed to close connection: %v"
@@ -86,7 +85,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag), Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
}) })
if err != nil { if err != nil {
@@ -197,7 +196,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag)) statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr return waitErr
@@ -207,7 +206,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Creating debug bundle...") cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
@@ -272,15 +271,13 @@ func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatusOutput(cmd *cobra.Command, anon bool) string { func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string var statusOutputString string
statusResp, err := getStatus(cmd.Context()) statusResp, err := getStatus(cmd.Context())
if err != nil { if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err) cmd.PrintErrf("Failed to get status: %v\n", err)
} else { } else {
statusOutputString = nbstatus.ParseToFullDetailSummary( statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
)
} }
return statusOutputString return statusOutputString
} }

View File

@@ -1,98 +0,0 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -85,17 +85,11 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {

View File

@@ -38,7 +38,6 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
) )
var ( var (
@@ -74,7 +73,6 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -145,7 +143,6 @@ func init() {
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
@@ -154,8 +151,6 @@ func init() {
networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd) debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd) logCmd.AddCommand(logLevelCmd)

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
@@ -72,7 +73,7 @@ var sshCmd = &cobra.Command{
go func() { go func() {
// blocking // blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err) log.Debug(err)
os.Exit(1) os.Exit(1)
} }
cancel() cancel()

View File

@@ -2,20 +2,107 @@ package cmd
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"runtime"
"sort"
"strings" "strings"
"time"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
) )
type peerStateDetailOutput struct {
FQDN string `json:"fqdn" yaml:"fqdn"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
}
type peersStateOutput struct {
Total int `json:"total" yaml:"total"`
Connected int `json:"connected" yaml:"connected"`
Details []peerStateDetailOutput `json:"details" yaml:"details"`
}
type signalStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type managementStateOutput struct {
URL string `json:"url" yaml:"url"`
Connected bool `json:"connected" yaml:"connected"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
}
type relayStateOutput struct {
Total int `json:"total" yaml:"total"`
Available int `json:"available" yaml:"available"`
Details []relayStateOutputDetail `json:"details" yaml:"details"`
}
type iceCandidateType struct {
Local string `json:"local" yaml:"local"`
Remote string `json:"remote" yaml:"remote"`
}
type nsServerGroupStateOutput struct {
Servers []string `json:"servers" yaml:"servers"`
Domains []string `json:"domains" yaml:"domains"`
Enabled bool `json:"enabled" yaml:"enabled"`
Error string `json:"error" yaml:"error"`
}
type statusOutputOverview struct {
Peers peersStateOutput `json:"peers" yaml:"peers"`
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
ManagementState managementStateOutput `json:"management" yaml:"management"`
SignalState signalStateOutput `json:"signal" yaml:"signal"`
Relays relayStateOutput `json:"relays" yaml:"relays"`
IP string `json:"netbirdIp" yaml:"netbirdIp"`
PubKey string `json:"publicKey" yaml:"publicKey"`
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
FQDN string `json:"fqdn" yaml:"fqdn"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
var ( var (
detailFlag bool detailFlag bool
ipv4Flag bool ipv4Flag bool
@@ -86,17 +173,18 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap) outputInformationHolder := convertToStatusOutputOverview(resp)
var statusOutputString string var statusOutputString string
switch { switch {
case detailFlag: case detailFlag:
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder) statusOutputString = parseToFullDetailSummary(outputInformationHolder)
case jsonFlag: case jsonFlag:
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder) statusOutputString, err = parseToJSON(outputInformationHolder)
case yamlFlag: case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder) statusOutputString, err = parseToYAML(outputInformationHolder)
default: default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false) statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
} }
if err != nil { if err != nil {
@@ -126,6 +214,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected": case "", "disconnected", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
@@ -162,6 +251,175 @@ func enableDetailFlagWhenFilterFlag() {
} }
} }
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
pbFullStatus := resp.GetFullStatus()
managementState := pbFullStatus.GetManagementState()
managementOverview := managementStateOutput{
URL: managementState.GetURL(),
Connected: managementState.GetConnected(),
Error: managementState.Error,
}
signalState := pbFullStatus.GetSignalState()
signalOverview := signalStateOutput{
URL: signalState.GetURL(),
Connected: signalState.GetConnected(),
Error: signalState.Error,
}
relayOverview := mapRelays(pbFullStatus.GetRelays())
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
overview := statusOutputOverview{
Peers: peersOverview,
CliVersion: version.NetbirdVersion(),
DaemonVersion: resp.GetDaemonVersion(),
ManagementState: managementOverview,
SignalState: signalOverview,
Relays: relayOverview,
IP: pbFullStatus.GetLocalPeerState().GetIP(),
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview
}
func mapRelays(relays []*proto.RelayState) relayStateOutput {
var relayStateDetail []relayStateOutputDetail
var relaysAvailable int
for _, relay := range relays {
available := relay.GetAvailable()
relayStateDetail = append(relayStateDetail,
relayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
},
)
if available {
relaysAvailable++
}
}
return relayStateOutput{
Total: len(relays),
Available: relaysAvailable,
Details: relayStateDetail,
}
}
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
Servers: pbNsGroupServer.GetServers(),
Domains: pbNsGroupServer.GetDomains(),
Enabled: pbNsGroupServer.GetEnabled(),
Error: pbNsGroupServer.GetError(),
})
}
return mappedNSGroups
}
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := ""
remoteICE := ""
localICEEndpoint := ""
remoteICEEndpoint := ""
relayServerAddress := ""
connType := ""
lastHandshake := time.Time{}
transferReceived := int64(0)
transferSent := int64(0)
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
}
if isPeerConnected {
peersConnected++
localICE = pbPeerState.GetLocalIceCandidateType()
remoteICE = pbPeerState.GetRemoteIceCandidateType()
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
connType = "P2P"
if pbPeerState.Relayed {
connType = "Relayed"
}
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
}
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
peerState := peerStateDetailOutput{
IP: pbPeerState.GetIP(),
PubKey: pbPeerState.GetPubKey(),
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
},
IceCandidateEndpoint: iceCandidateType{
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetNetworks(),
Networks: pbPeerState.GetNetworks(),
}
peersStateDetail = append(peersStateDetail, peerState)
}
sortPeersByIP(peersStateDetail)
peersOverview := peersStateOutput{
Total: len(peersStateDetail),
Connected: peersConnected,
Details: peersStateDetail,
}
return peersOverview
}
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
if len(peersStateDetail) > 0 {
sort.SliceStable(peersStateDetail, func(i, j int) bool {
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
return iAddr.Compare(jAddr) == -1
})
}
}
func parseInterfaceIP(interfaceIP string) string { func parseInterfaceIP(interfaceIP string) string {
ip, _, err := net.ParseCIDR(interfaceIP) ip, _, err := net.ParseCIDR(interfaceIP)
if err != nil { if err != nil {
@@ -169,3 +427,452 @@ func parseInterfaceIP(interfaceIP string) string {
} }
return fmt.Sprintf("%s\n", ip) return fmt.Sprintf("%s\n", ip)
} }
func parseToJSON(overview statusOutputOverview) (string, error) {
jsonBytes, err := json.Marshal(overview)
if err != nil {
return "", fmt.Errorf("json marshal failed")
}
return string(jsonBytes), err
}
func parseToYAML(overview statusOutputOverview) (string, error) {
yamlBytes, err := yaml.Marshal(overview)
if err != nil {
return "", fmt.Errorf("yaml marshal failed")
}
return string(yamlBytes), nil
}
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
var managementConnString string
if overview.ManagementState.Connected {
managementConnString = "Connected"
if showURL {
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
}
} else {
managementConnString = "Disconnected"
if overview.ManagementState.Error != "" {
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
}
}
var signalConnString string
if overview.SignalState.Connected {
signalConnString = "Connected"
if showURL {
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
}
} else {
signalConnString = "Disconnected"
if overview.SignalState.Error != "" {
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
}
}
interfaceTypeString := "Userspace"
interfaceIP := overview.IP
if overview.KernelInterface {
interfaceTypeString = "Kernel"
} else if overview.IP == "" {
interfaceTypeString = "N/A"
interfaceIP = "N/A"
}
var relaysString string
if showRelays {
for _, relay := range overview.Relays.Details {
available := "Available"
reason := ""
if !relay.Available {
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
}
} else {
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
}
var dnsServersString string
if showNameServers {
for _, nsServerGroup := range overview.NSServerGroups {
enabled := "Available"
if !nsServerGroup.Enabled {
enabled = "Unavailable"
}
errorString := ""
if nsServerGroup.Error != "" {
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
errorString = strings.TrimSpace(errorString)
}
domainsString := strings.Join(nsServerGroup.Domains, ", ")
if domainsString == "" {
domainsString = "." // Show "." for the default zone
}
dnsServersString += fmt.Sprintf(
"\n [%s] for [%s] is %s%s",
strings.Join(nsServerGroup.Servers, ", "),
domainsString,
enabled,
errorString,
)
}
} else {
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
}
rosenpassEnabledStatus := "false"
if overview.RosenpassEnabled {
rosenpassEnabledStatus = "true"
if overview.RosenpassPermissive {
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
}
}
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+
"CLI version: %s\n"+
"Management: %s\n"+
"Signal: %s\n"+
"Relays: %s\n"+
"Nameservers: %s\n"+
"FQDN: %s\n"+
"NetBird IP: %s\n"+
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
version.NetbirdVersion(),
managementConnString,
signalConnString,
relaysString,
dnsServersString,
overview.FQDN,
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
networks,
networks,
peersCountString,
)
return summary
}
func parseToFullDetailSummary(overview statusOutputOverview) string {
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
summary := parseGeneralSummary(overview, true, true, true)
return fmt.Sprintf(
"Peers detail:"+
"%s\n"+
"%s",
parsedPeersString,
summary,
)
}
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
var (
peersString = ""
)
for _, peerState := range peers.Details {
localICE := "-"
if peerState.IceCandidateType.Local != "" {
localICE = peerState.IceCandidateType.Local
}
remoteICE := "-"
if peerState.IceCandidateType.Remote != "" {
remoteICE = peerState.IceCandidateType.Remote
}
localICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Local != "" {
localICEEndpoint = peerState.IceCandidateEndpoint.Local
}
remoteICEEndpoint := "-"
if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
}
rosenpassEnabledStatus := "false"
if rosenpassEnabled {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "true"
} else {
if rosenpassPermissive {
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
} else {
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
}
}
} else {
if peerState.RosenpassEnabled {
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
}
}
networks := "-"
if len(peerState.Networks) > 0 {
sort.Strings(peerState.Networks)
networks = strings.Join(peerState.Networks, ", ")
}
peerString := fmt.Sprintf(
"\n %s:\n"+
" NetBird IP: %s\n"+
" Public key: %s\n"+
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
peerState.PubKey,
peerState.Status,
peerState.ConnType,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
networks,
networks,
peerState.Latency.String(),
)
peersString += peerString
}
return peersString
}
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false
ipEval := false
nameEval := true
if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter)
if lowerStatusFilter == "disconnected" && isConnected {
statusEval = true
} else if lowerStatusFilter == "connected" && !isConnected {
statusEval = true
}
}
if len(ipsFilter) > 0 {
_, ok := ipsFilterMap[peerState.IP]
if !ok {
ipEval = true
}
}
if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap {
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = false
break
}
}
} else {
nameEval = false
}
return statusEval || ipEval || nameEval
}
func toIEC(b int64) string {
const unit = 1024
if b < unit {
return fmt.Sprintf("%d B", b)
}
div, exp := int64(unit), 0
for n := b / unit; n >= unit; n /= unit {
div *= unit
exp++
}
return fmt.Sprintf("%.1f %ciB",
float64(b)/float64(div), "KMGTPE"[exp])
}
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
count := 0
for _, server := range dnsServers {
if server.Enabled {
count++
}
}
return count
}
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -1,11 +1,597 @@
package cmd package cmd
import ( import (
"bytes"
"encoding/json"
"fmt"
"runtime"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
) )
func init() {
loc, err := time.LoadLocation("UTC")
if err != nil {
panic(err)
}
time.Local = loc
}
var resp = &proto.StatusResponse{
Status: "Connected",
FullStatus: &proto.FullStatus{
Peers: []*proto.PeerState{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
Fqdn: "peer-1.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
RemoteIceCandidateEndpoint: "",
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Networks: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
Fqdn: "peer-2.awesome-domain.com",
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
BytesRx: 2000,
BytesTx: 1000,
Latency: durationpb.New(time.Duration(10000000)),
},
},
ManagementState: &proto.ManagementState{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: &proto.SignalState{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: []*proto.RelayState{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
LocalPeerState: &proto.LocalPeerState{
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Networks: []string{
"10.10.0.0/24",
},
},
DnsServers: []*proto.NSGroupState{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
},
DaemonVersion: "0.14.1",
}
var overview = statusOutputOverview{
Peers: peersStateOutput{
Total: 2,
Connected: 2,
Details: []peerStateDetailOutput{
{
IP: "192.168.178.101",
PubKey: "Pubkey1",
FQDN: "peer-1.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
},
IceCandidateEndpoint: iceCandidateType{
Local: "",
Remote: "",
},
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
TransferReceived: 200,
TransferSent: 100,
Routes: []string{
"10.1.0.0/24",
},
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
IP: "192.168.178.102",
PubKey: "Pubkey2",
FQDN: "peer-2.awesome-domain.com",
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
},
IceCandidateEndpoint: iceCandidateType{
Local: "10.0.0.1:10001",
Remote: "10.0.10.1:10002",
},
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
TransferReceived: 2000,
TransferSent: 1000,
Latency: time.Duration(10000000),
},
},
},
CliVersion: version.NetbirdVersion(),
DaemonVersion: "0.14.1",
ManagementState: managementStateOutput{
URL: "my-awesome-management.com:443",
Connected: true,
Error: "",
},
SignalState: signalStateOutput{
URL: "my-awesome-signal.com:443",
Connected: true,
Error: "",
},
Relays: relayStateOutput{
Total: 2,
Available: 1,
Details: []relayStateOutputDetail{
{
URI: "stun:my-awesome-stun.com:3478",
Available: true,
Error: "",
},
{
URI: "turns:my-awesome-turn.com:443?transport=tcp",
Available: false,
Error: "context: deadline exceeded",
},
},
},
IP: "192.168.178.100/16",
PubKey: "Some-Pub-Key",
KernelInterface: true,
FQDN: "some-localhost.awesome-domain.com",
NSServerGroups: []nsServerGroupStateOutput{
{
Servers: []string{
"8.8.8.8:53",
},
Domains: nil,
Enabled: true,
Error: "",
},
{
Servers: []string{
"1.1.1.1:53",
"2.2.2.2:53",
},
Domains: []string{
"example.com",
"example.net",
},
Enabled: false,
Error: "timeout",
},
},
Routes: []string{
"10.10.0.0/24",
},
Networks: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
convertedResult := convertToStatusOutputOverview(resp)
assert.Equal(t, overview, convertedResult)
}
func TestSortingOfPeers(t *testing.T) {
peers := []peerStateDetailOutput{
{
IP: "192.168.178.104",
},
{
IP: "192.168.178.102",
},
{
IP: "192.168.178.101",
},
{
IP: "192.168.178.105",
},
{
IP: "192.168.178.103",
},
}
sortPeersByIP(peers)
assert.Equal(t, peers[3].IP, "192.168.178.104")
}
func TestParsingToJSON(t *testing.T) {
jsonString, _ := parseToJSON(overview)
//@formatter:off
expectedJSONString := `
{
"peers": {
"total": 2,
"connected": 2,
"details": [
{
"fqdn": "peer-1.awesome-domain.com",
"netbirdIp": "192.168.178.101",
"publicKey": "Pubkey1",
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
"iceCandidateType": {
"local": "",
"remote": ""
},
"iceCandidateEndpoint": {
"local": "",
"remote": ""
},
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
"latency": 10000000,
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
]
},
{
"fqdn": "peer-2.awesome-domain.com",
"netbirdIp": "192.168.178.102",
"publicKey": "Pubkey2",
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
},
"iceCandidateEndpoint": {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null,
"networks": null
}
]
},
"cliVersion": "development",
"daemonVersion": "0.14.1",
"management": {
"url": "my-awesome-management.com:443",
"connected": true,
"error": ""
},
"signal": {
"url": "my-awesome-signal.com:443",
"connected": true,
"error": ""
},
"relays": {
"total": 2,
"available": 1,
"details": [
{
"uri": "stun:my-awesome-stun.com:3478",
"available": true,
"error": ""
},
{
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
"available": false,
"error": "context: deadline exceeded"
}
]
},
"netbirdIp": "192.168.178.100/16",
"publicKey": "Some-Pub-Key",
"usesKernelInterface": true,
"fqdn": "some-localhost.awesome-domain.com",
"quantumResistance": false,
"quantumResistancePermissive": false,
"routes": [
"10.10.0.0/24"
],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
"8.8.8.8:53"
],
"domains": null,
"enabled": true,
"error": ""
},
{
"servers": [
"1.1.1.1:53",
"2.2.2.2:53"
],
"domains": [
"example.com",
"example.net"
],
"enabled": false,
"error": "timeout"
}
]
}`
// @formatter:on
var expectedJSON bytes.Buffer
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
assert.Equal(t, expectedJSON.String(), jsonString)
}
func TestParsingToYAML(t *testing.T) {
yaml, _ := parseToYAML(overview)
expectedYAML :=
`peers:
total: 2
connected: 2
details:
- fqdn: peer-1.awesome-domain.com
netbirdIp: 192.168.178.101
publicKey: Pubkey1
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
latency: 10ms
quantumResistance: false
routes:
- 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
latency: 10ms
quantumResistance: false
routes: []
networks: []
cliVersion: development
daemonVersion: 0.14.1
management:
url: my-awesome-management.com:443
connected: true
error: ""
signal:
url: my-awesome-signal.com:443
connected: true
error: ""
relays:
total: 2
available: 1
details:
- uri: stun:my-awesome-stun.com:3478
available: true
error: ""
- uri: turns:my-awesome-turn.com:443?transport=tcp
available: false
error: 'context: deadline exceeded'
netbirdIp: 192.168.178.100/16
publicKey: Some-Pub-Key
usesKernelInterface: true
fqdn: some-localhost.awesome-domain.com
quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
domains: []
enabled: true
error: ""
- servers:
- 1.1.1.1:53
- 2.2.2.2:53
domains:
- example.com
- example.net
enabled: false
error: timeout
`
assert.Equal(t, expectedYAML, yaml)
}
func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf(
`Peers detail:
peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101
Public key: Pubkey1
Status: Connected
-- detail --
Connection type: P2P
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
NetBird IP: 192.168.178.102
Public key: Pubkey2
Status: Connected
-- detail --
Connection type: Relayed
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Networks: -
Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1
CLI version: %s
Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443
Relays:
[stun:my-awesome-stun.com:3478] is Available
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
Nameservers:
[8.8.8.8:53] for [.] is Available
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
assert.Equal(t, expectedDetail, detail)
}
func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
Daemon version: 0.14.1
CLI version: development
Management: Connected
Signal: Connected
Relays: 1/2 Available
Nameservers: 1/2 Available
FQDN: some-localhost.awesome-domain.com
NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`
assert.Equal(t, expectedString, shortVersion)
}
func TestParsingOfIP(t *testing.T) { func TestParsingOfIP(t *testing.T) {
InterfaceIP := "192.168.178.123/16" InterfaceIP := "192.168.178.123/16"
@@ -13,3 +599,31 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP) assert.Equal(t, "192.168.178.123\n", parsedIP)
} }
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -1,31 +0,0 @@
package cmd
// Flag constants for system configuration
const (
disableClientRoutesFlag = "disable-client-routes"
disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall"
)
var (
disableClientRoutes bool
disableServerRoutes bool
disableDNS bool
disableFirewall bool
)
func init() {
// Add system flags to upCmd
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
}

View File

@@ -10,7 +10,6 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -90,13 +89,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock()) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -1,137 +0,0 @@
package cmd
import (
"fmt"
"math/rand"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var traceCmd = &cobra.Command{
Use: "trace <direction> <source-ip> <dest-ip>",
Short: "Trace a packet through the firewall",
Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3),
RunE: tracePacket,
}
func init() {
debugCmd.AddCommand(traceCmd)
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
traceCmd.Flags().Uint16("sport", 0, "Source port")
traceCmd.Flags().Uint16("dport", 0, "Destination port")
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
}
func tracePacket(cmd *cobra.Command, args []string) error {
direction := strings.ToLower(args[0])
if direction != "in" && direction != "out" {
return fmt.Errorf("invalid direction: use 'in' or 'out'")
}
protocol := cmd.Flag("protocol").Value.String()
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
}
sport, err := cmd.Flags().GetUint16("sport")
if err != nil {
return fmt.Errorf("invalid source port: %v", err)
}
dport, err := cmd.Flags().GetUint16("dport")
if err != nil {
return fmt.Errorf("invalid destination port: %v", err)
}
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
if protocol != "icmp" {
if sport == 0 {
sport = uint16(rand.Intn(16383) + 49152)
}
if dport == 0 {
dport = uint16(rand.Intn(16383) + 49152)
}
}
var tcpFlags *proto.TCPFlags
if protocol == "tcp" {
syn, _ := cmd.Flags().GetBool("syn")
ack, _ := cmd.Flags().GetBool("ack")
fin, _ := cmd.Flags().GetBool("fin")
rst, _ := cmd.Flags().GetBool("rst")
psh, _ := cmd.Flags().GetBool("psh")
urg, _ := cmd.Flags().GetBool("urg")
tcpFlags = &proto.TCPFlags{
Syn: syn,
Ack: ack,
Fin: fin,
Rst: rst,
Psh: psh,
Urg: urg,
}
}
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
SourceIp: args[1],
DestinationIp: args[2],
Protocol: protocol,
SourcePort: uint32(sport),
DestinationPort: uint32(dport),
Direction: direction,
TcpFlags: tcpFlags,
IcmpType: &icmpType,
IcmpCode: &icmpCode,
})
if err != nil {
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
}
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
return nil
}
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil {
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
} else {
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
}
}
disposition := map[bool]string{
true: "\033[32mALLOWED\033[0m", // Green
false: "\033[31mDENIED\033[0m", // Red
}[resp.FinalDisposition]
cmd.Printf("\nFinal disposition: %s\n", disposition)
}

View File

@@ -20,7 +20,6 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -30,16 +29,9 @@ const (
interfaceInputType interfaceInputType
) )
const (
dnsLabelsFlag = "extra-dns-labels"
)
var ( var (
foregroundMode bool foregroundMode bool
dnsLabels []string upCmd = &cobra.Command{
dnsLabelsValidated domain.List
upCmd = &cobra.Command{
Use: "up", Use: "up",
Short: "install, login and start Netbird client", Short: "install, login and start Netbird client",
RunE: upFunc, RunE: upFunc,
@@ -56,15 +48,6 @@ func init() {
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -83,11 +66,6 @@ func upFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
if err != nil {
return err
}
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(cmd.Context())
if hostName != "" { if hostName != "" {
@@ -119,7 +97,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs, NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList, ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
} }
if cmd.Flag(enableRosenpassFlag).Changed { if cmd.Flag(enableRosenpassFlag).Changed {
@@ -170,23 +147,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval ic.DNSRouteInterval = &dnsRouteInterval
} }
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return err return err
@@ -212,7 +172,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus() r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r) connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run(nil) return connectClient.Run()
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -262,8 +222,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
} }
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
@@ -306,23 +264,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval) loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
} }
if cmd.Flag(disableClientRoutesFlag).Changed {
loginRequest.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
loginRequest.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
loginRequest.DisableDns = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
loginRequest.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
loginRequest.BlockLanAccess = &blockLANAccess
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse
@@ -454,24 +395,6 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil return parsed, nil
} }
func validateDnsLabels(labels []string) (domain.List, error) {
var (
domains domain.List
err error
)
if len(labels) == 0 {
return domains, nil
}
domains, err = domain.ValidateDomains(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}
return domains, nil
}
func isValidAddrPort(input string) bool { func isValidAddrPort(input string) bool {
if input == "" { if input == "" {
return true return true

View File

@@ -1,24 +0,0 @@
package configs
import (
"os"
"path/filepath"
"runtime"
)
var StateDir string
func init() {
StateDir = os.Getenv("NB_STATE_DIR")
if StateDir != "" {
return
}
switch runtime.GOOS {
case "windows":
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
case "darwin", "linux":
StateDir = "/var/lib/netbird"
case "freebsd", "openbsd", "netbsd", "dragonfly":
StateDir = "/var/db/netbird"
}
}

View File

@@ -1,167 +0,0 @@
// Package embed provides a way to embed the NetBird client directly
// into Go programs without requiring a separate NetBird client installation.
package embed
// Basic Usage:
//
// client, err := embed.New(embed.Options{
// DeviceName: "my-service",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// })
// if err != nil {
// log.Fatal(err)
// }
//
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// Complete HTTP Server Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "log"
// "net/http"
// "os"
// "os/signal"
// "syscall"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-server",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP server
// mux := http.NewServeMux()
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
// fmt.Fprintf(w, "Hello from netbird!")
// })
//
// // Listen on netbird network
// l, err := client.ListenTCP(":8080")
// if err != nil {
// log.Fatal(err)
// }
//
// server := &http.Server{Handler: mux}
// go func() {
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
// log.Printf("HTTP server error: %v", err)
// }
// }()
//
// log.Printf("HTTP server listening on netbird network port 8080")
//
// // Handle shutdown
// stop := make(chan os.Signal, 1)
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
// <-stop
//
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := server.Shutdown(shutdownCtx); err != nil {
// log.Printf("HTTP shutdown error: %v", err)
// }
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// Complete HTTP Client Example:
//
// package main
//
// import (
// "context"
// "fmt"
// "io"
// "log"
// "os"
// "time"
//
// netbird "github.com/netbirdio/netbird/client/embed"
// )
//
// func main() {
// // Create client with setup key and device name
// client, err := netbird.New(netbird.Options{
// DeviceName: "http-client",
// SetupKey: os.Getenv("NB_SETUP_KEY"),
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
// LogOutput: io.Discard,
// })
// if err != nil {
// log.Fatal(err)
// }
//
// // Start with timeout
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
//
// if err := client.Start(ctx); err != nil {
// log.Fatal(err)
// }
//
// // Create HTTP client that uses netbird network
// httpClient := client.NewHTTPClient()
// httpClient.Timeout = 10 * time.Second
//
// // Make request to server in netbird network
// target := os.Getenv("NB_TARGET")
// resp, err := httpClient.Get(target)
// if err != nil {
// log.Fatal(err)
// }
// defer resp.Body.Close()
//
// // Read and print response
// body, err := io.ReadAll(resp.Body)
// if err != nil {
// log.Fatal(err)
// }
//
// fmt.Printf("Response from server: %s\n", string(body))
//
// // Clean shutdown
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// defer cancel()
//
// if err := client.Stop(shutdownCtx); err != nil {
// log.Printf("Netbird shutdown error: %v", err)
// }
// }
//
// The package provides several methods for network operations:
// - Dial: Creates outbound connections
// - ListenTCP: Creates TCP listeners
// - ListenUDP: Creates UDP listeners
//
// By default, the embed package uses userspace networking mode, which doesn't
// require root/admin privileges. For production deployments, consider setting
// appropriate config and state paths for persistence.

View File

@@ -1,293 +0,0 @@
package embed
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"sync"
"github.com/sirupsen/logrus"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
// Client manages a netbird embedded client instance
type Client struct {
deviceName string
config *internal.Config
mu sync.Mutex
cancel context.CancelFunc
setupKey string
connect *internal.ConnectClient
}
// Options configures a new Client
type Options struct {
// DeviceName is this peer's name in the network
DeviceName string
// SetupKey is used for authentication
SetupKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
// LogLevel sets the logging level (defaults to info if empty)
LogLevel string
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
NoUserspace bool
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
ConfigPath string
// StatePath is the path to the netbird state file
StatePath string
// DisableClientRoutes disables the client routes
DisableClientRoutes bool
}
// New creates a new netbird embedded client
func New(opts Options) (*Client, error) {
if opts.LogOutput != nil {
logrus.SetOutput(opts.LogOutput)
}
if opts.LogLevel != "" {
level, err := logrus.ParseLevel(opts.LogLevel)
if err != nil {
return nil, fmt.Errorf("parse log level: %w", err)
}
logrus.SetLevel(level)
}
if !opts.NoUserspace {
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
if opts.StatePath != "" {
// TODO: Disable state if path not provided
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
return nil, fmt.Errorf("setenv: %w", err)
}
}
t := true
var config *internal.Config
var err error
input := internal.ConfigInput{
ConfigPath: opts.ConfigPath,
ManagementURL: opts.ManagementURL,
PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes,
}
if opts.ConfigPath != "" {
config, err = internal.UpdateOrCreateConfig(input)
} else {
config, err = internal.CreateInMemoryConfig(input)
}
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
return &Client{
deviceName: opts.DeviceName,
setupKey: opts.SetupKey,
config: config,
}, nil
}
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
func (c *Client) Start(startCtx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.cancel != nil {
return ErrClientAlreadyStarted
}
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
return fmt.Errorf("login: %w", err)
}
recorder := peer.NewRecorder(c.config.ManagementURL.String())
client := internal.NewConnectClient(ctx, c.config, recorder)
// either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available
run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() {
if err := client.Run(run); err != nil {
clientErr <- err
}
}()
select {
case <-startCtx.Done():
if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
}
return startCtx.Err()
case err := <-clientErr:
return fmt.Errorf("startup: %w", err)
case <-run:
}
c.connect = client
return nil
}
// Stop gracefully stops the client.
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
func (c *Client) Stop(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
if c.connect == nil {
return ErrClientNotStarted
}
done := make(chan error, 1)
go func() {
done <- c.connect.Stop()
}()
select {
case <-ctx.Done():
c.cancel = nil
return ctx.Err()
case err := <-done:
c.cancel = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}
}
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, fmt.Errorf("get net: %w", err)
}
return nsnet.DialContext(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenTCP(tcpAddr)
}
// ListenUDP listens on the given address in the netbird network
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
nsnet, addr, err := c.getNet()
if err != nil {
return nil, err
}
_, port, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("split host port: %w", err)
}
listenAddr := fmt.Sprintf("%s:%s", addr, port)
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
if err != nil {
return nil, fmt.Errorf("resolve: %w", err)
}
return nsnet.ListenUDP(udpAddr)
}
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) NewHTTPClient() *http.Client {
transport := &http.Transport{
DialContext: c.Dial,
}
return &http.Client{
Transport: transport,
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
}
addr, err := engine.Address()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
}
nsnet, err := engine.GetNet()
if err != nil {
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
}
return nsnet, addr, nil
}

View File

@@ -10,18 +10,17 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) fm, err := uspfilter.Create(iface)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,7 +15,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -34,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) fm, err := createNativeFirewall(iface, stateManager)
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return fm, err return fm, err
@@ -48,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) return createUserspaceFirewall(iface, fm)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface) fm, err := createFW(iface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create firewall: %s", err) return nil, fmt.Errorf("create firewall: %s", err)
@@ -78,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else { } else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) fm, errUsp = uspfilter.Create(iface)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -1,18 +1,13 @@
package firewall package firewall
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() wgaddr.Address Address() device.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice
GetWGDevice() *wgdevice.Device
} }

View File

@@ -3,7 +3,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net"
"slices" "strconv"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/uuid" "github.com/google/uuid"
@@ -19,7 +19,8 @@ const (
tableName = "filter" tableName = "filter"
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type aclEntries map[string][][]string type aclEntries map[string][][]string
@@ -30,8 +31,10 @@ type entry struct {
} }
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string
entries aclEntries entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
@@ -39,10 +42,12 @@ type aclManager struct {
stateManager *statemanager.Manager stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
@@ -75,27 +80,32 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
chain := chainNameInputRules var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
ipsetName = transformIPsetName(ipsetName, sPort, dPort) var chain string
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName) if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
mangleSpecs := slices.Clone(specs) ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
mangleSpecs = append(mangleSpecs, specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" { if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists { if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil { if err := ipset.Add(ipsetName, ip.String()); err != nil {
@@ -127,7 +137,7 @@ func (m *aclManager) AddPeerFiltering(
m.ipsetStore.addIpList(ipsetName, ipList) m.ipsetStore.addIpList(ipsetName, ipList)
} }
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...) ok, err := m.iptablesClient.Exists("filter", chain, specs...)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err) return nil, fmt.Errorf("failed to check rule: %w", err)
} }
@@ -135,22 +145,16 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil { if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
return nil, err return nil, err
} }
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{ rule := &Rule{
ruleID: uuid.New().String(), ruleID: uuid.New().String(),
specs: specs, specs: specs,
mangleSpecs: mangleSpecs, ipsetName: ipsetName,
ipsetName: ipsetName, ip: ip.String(),
ip: ip.String(), chain: chain,
chain: chain,
} }
m.updateState() m.updateState()
@@ -193,12 +197,6 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
} }
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
m.updateState() m.updateState()
return nil return nil
@@ -216,7 +214,28 @@ func (m *aclManager) Reset() error {
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error { func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules) ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil { if err != nil {
log.Debugf("failed to list chains: %s", err) log.Debugf("failed to list chains: %s", err)
return err return err
@@ -276,6 +295,12 @@ func (m *aclManager) createDefaultChains() error {
return err return err
} }
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
@@ -304,6 +329,8 @@ func (m *aclManager) createDefaultChains() error {
// The existing FORWARD rules/policies decide outbound traffic towards our interface. // The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished() established := getConntrackEstablished()
@@ -311,21 +338,25 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
} }
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{ m.optionalEntries["FORWARD"] = []entry{
{ {
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"}, spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2, position: 2,
}, },
} }
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
@@ -359,26 +390,42 @@ func (m *aclManager) updateState() {
} }
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) { func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
matchByIP := true matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0 // don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" { if ip.String() == "0.0.0.0" {
matchByIP = false matchByIP = false
} }
switch direction {
if matchByIP { case firewall.RuleDirectionIN:
if ipsetName != "" { if matchByIP {
specs = append(specs, "-m", "set", "--set", ipsetName, "src") if ipsetName != "" {
} else { specs = append(specs, "-m", "set", "--set", ipsetName, "src")
specs = append(specs, "-s", ip.String()) } else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
} }
} }
if protocol != "all" { if protocol != "all" {
specs = append(specs, "-p", protocol) specs = append(specs, "-p", protocol)
} }
specs = append(specs, applyPort("--sport", sPort)...) if sPort != "" {
specs = append(specs, applyPort("--dport", dPort)...) specs = append(specs, "--sport", sPort)
return specs }
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
} }
func actionToStr(action firewall.Action) string { func actionToStr(action firewall.Action) string {
@@ -388,15 +435,15 @@ func actionToStr(action firewall.Action) string {
return "DROP" return "DROP"
} }
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string { func transformIPsetName(ipsetName string, sPort, dPort string) string {
switch { switch {
case ipsetName == "": case ipsetName == "":
return "" return ""
case sPort != nil && dPort != nil: case sPort != "" && dPort != "":
return ipsetName + "-sport-dport" return ipsetName + "-sport-dport"
case sPort != nil: case sPort != "":
return ipsetName + "-sport" return ipsetName + "-sport"
case dPort != nil: case dPort != "":
return ipsetName + "-dport" return ipsetName + "-dport"
default: default:
return ipsetName return ipsetName

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() wgaddr.Address Address() iface.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface) m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
@@ -96,22 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -126,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -167,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -197,13 +197,14 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, net.ParseIP("0.0.0.0"),
net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,
firewall.RuleDirectionIN,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"",
) )
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)
@@ -214,35 +215,6 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() iface.WGAddress {
return wgaddr.Address{ return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() wgaddr.Address AddressFunc func() iface.WGAddress
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() wgaddr.Address { func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -62,20 +62,33 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Close(nil) err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, Values: []int{8043: 8046},
Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -84,6 +97,15 @@ func TestIptablesManager(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@@ -96,29 +118,32 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists") require.NoError(t, err, "failed check chain exists")
if ok { if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules) require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
} }
}) })
} }
func TestIptablesManagerIPSet(t *testing.T) { func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{ mock := &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() iface.WGAddress {
return wgaddr.Address{ return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -136,19 +161,39 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Close(nil) err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []int{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -156,6 +201,15 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
}) })
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeletePeerRule(r)
@@ -166,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Close(nil) err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -184,8 +238,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() iface.WGAddress {
return wgaddr.Address{ return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -204,7 +258,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Close(nil) err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -215,8 +269,12 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -15,8 +15,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -24,36 +23,22 @@ import (
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle" tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING" chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWD = "NETBIRD-RT-FWD"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE" chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre" jumpPre = "jump-pre"
jumpNatPre = "jump-nat-pre" jumpNat = "jump-nat"
jumpNatPost = "jump-nat-post" matchSet = "--match-set"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
) )
type ruleInfo struct {
chain string
table string
rule []string
}
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Sources []netip.Prefix Sources []netip.Prefix
Destination netip.Prefix Destination netip.Prefix
@@ -77,7 +62,6 @@ type router struct {
legacyManagement bool legacyManagement bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@@ -85,7 +69,6 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -121,7 +104,6 @@ func (r *router) init(stateManager *statemanager.Manager) error {
} }
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -129,7 +111,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }
@@ -153,16 +135,7 @@ func (r *router) AddRouteFiltering(
} }
rule := genRouteFilteringRuleSpec(params) rule := genRouteFilteringRuleSpec(params)
// Insert DROP rules at the beginning, append ACCEPT rules at the end if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
var err error
if action == firewall.ActionDrop {
// after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
} else {
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
}
if err != nil {
return nil, fmt.Errorf("add route rule: %v", err) return nil, fmt.Errorf("add route rule: %v", err)
} }
@@ -174,12 +147,12 @@ func (r *router) AddRouteFiltering(
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule) setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
@@ -230,10 +203,6 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -260,10 +229,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
@@ -290,7 +255,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
} }
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
@@ -303,7 +268,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
@@ -331,7 +296,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue continue
} }
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else { } else {
delete(r.rules, k) delete(r.rules, k)
@@ -369,11 +334,9 @@ func (r *router) cleanUpDefaultForwardRules() error {
chain string chain string
table string table string
}{ }{
{chainRTFWDIN, tableFilter}, {chainRTFWD, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat}, {chainRTPRE, tableMangle},
} { } {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
@@ -393,22 +356,16 @@ func (r *router) createContainers() error {
chain string chain string
table string table string
}{ }{
{chainRTFWDIN, tableFilter}, {chainRTFWD, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} { } {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
} }
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil { if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
return fmt.Errorf("insert established rule: %w", err) return fmt.Errorf("insert established rule: %w", err)
} }
@@ -449,6 +406,27 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished() establishedRule := getConntrackEstablished()
@@ -467,43 +445,28 @@ func (r *router) addJumpRules() error {
// Jump to NAT chain // Jump to NAT chain
natRule := []string{"-j", chainRTNAT} natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err) return fmt.Errorf("add nat jump rule: %v", err)
} }
r.rules[jumpNatPost] = natRule r.rules[jumpNat] = natRule
// Jump to mangle prerouting chain // Jump to prerouting chain
preRule := []string{"-j", chainRTPRE} preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil { if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add mangle prerouting jump rule: %v", err) return fmt.Errorf("add prerouting jump rule: %v", err)
} }
r.rules[jumpManglePre] = preRule r.rules[jumpPre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %v", err)
}
r.rules[jumpNatPre] = rdrRule
return nil return nil
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
var table, chain string table := tableNat
switch ruleKey { chain := chainPOSTROUTING
case jumpNatPost: if ruleKey == jumpPre {
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table = tableMangle table = tableMangle
chain = chainPREROUTING chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
} }
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
@@ -548,8 +511,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
r.rules[ruleKey] = rule r.rules[ruleKey] = rule
r.updateState()
return nil return nil
} }
@@ -565,7 +526,6 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
r.updateState()
return nil return nil
} }
@@ -595,137 +555,6 @@ func (r *router) updateState() {
} }
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[string]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string var rule []string
@@ -761,10 +590,10 @@ func applyPort(flag string, port *firewall.Port) []string {
if len(port.Values) > 1 { if len(port.Values) > 1 {
portList := make([]string, len(port.Values)) portList := make([]string, len(port.Values))
for i, p := range port.Values { for i, p := range port.Values {
portList[i] = strconv.Itoa(int(p)) portList[i] = strconv.Itoa(p)
} }
return []string{"-m", "multiport", flag, strings.Join(portList, ",")} return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
} }
return []string{flag, strconv.Itoa(int(port.Values[0]))} return []string{flag, strconv.Itoa(port.Values[0])}
} }

View File

@@ -39,14 +39,12 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
}() }()
// Now 5 rules: // Now 5 rules:
// 1. established rule forward in // 1. established rule in forward chain
// 2. estbalished rule forward out // 2. jump rule to NAT chain
// 3. jump rule to POST nat chain // 3. jump rule to PRE chain
// 4. jump rule to PRE mangle chain // 4. static outbound masquerade rule
// 5. jump rule to PRE nat chain // 5. static return masquerade rule
// 6. static outbound masquerade rule require.Len(t, manager.rules, 5, "should have created rules map")
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -241,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []uint16{80}}, dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,
@@ -254,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}, },
destination: netip.MustParsePrefix("10.0.0.0/8"), destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true}, sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop, action: firewall.ActionDrop,
@@ -287,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"), destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}}, sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
@@ -299,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true}, dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop, action: firewall.ActionDrop,
expectSet: false, expectSet: false,
@@ -309,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"), destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true}, sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}}, dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,
@@ -330,18 +328,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()] rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule // Log the internal rule
t.Logf("Internal rule: %v", rule) t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables // Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")

View File

@@ -5,13 +5,12 @@ type Rule struct {
ruleID string ruleID string
ipsetName string ipsetName string
specs []string specs []string
mangleSpecs []string ip string
ip string chain string
chain string
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) ID() string { func (r *Rule) GetRuleID() string {
return r.ruleID return r.ruleID
} }

View File

@@ -4,20 +4,21 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() wgaddr.Address { func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress return i.WGAddress
} }
@@ -61,7 +62,7 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
if err := ipt.Close(nil); err != nil { if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -26,8 +26,8 @@ const (
// Each firewall type for different OS can use different type // Each firewall type for different OS can use different type
// of the properties to hold data of the created rule // of the properties to hold data of the created rule
type Rule interface { type Rule interface {
// ID returns the rule id // GetRuleID returns the rule id
ID() string GetRuleID() string
} }
// RuleDirection is the traffic direction which a rule is applied // RuleDirection is the traffic direction which a rule is applied
@@ -65,13 +65,14 @@ type Manager interface {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
direction RuleDirection,
action Action, action Action,
ipsetName string, ipsetName string,
comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -80,15 +81,7 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering( AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule // DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error DeleteRouteRule(rule Rule) error
@@ -102,23 +95,11 @@ type Manager interface {
// SetLegacyManagement sets the legacy management mode // SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error SetLegacyManagement(legacy bool) error
// Close closes the firewall manager // Reset firewall to the default state
Close(stateManager *statemanager.Manager) error Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
SetLogLevel(log.Level)
EnableRouting() error
DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -1,27 +0,0 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,37 +1,36 @@
package manager package manager
import ( import (
"fmt"
"strconv" "strconv"
) )
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule // Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct { type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port // IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool IsRange bool
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports // Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []uint16 Values []int
}
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
} }
// String interface implementation // String interface implementation
@@ -41,11 +40,7 @@ func (p *Port) String() string {
if ports != "" { if ports != "" {
ports += "," ports += ","
} }
ports += strconv.Itoa(int(port)) ports += strconv.Itoa(port)
} }
if p.IsRange {
ports = "range:" + ports
}
return ports return ports
} }

View File

@@ -1,19 +0,0 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -2,9 +2,9 @@ package nftables
import ( import (
"bytes" "bytes"
"encoding/binary"
"fmt" "fmt"
"net" "net"
"slices"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -22,7 +22,8 @@ import (
const ( const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules" chainNameInputRules = "netbird-acl-input-rules"
chainNameOutputRules = "netbird-acl-output-rules"
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
@@ -44,9 +45,9 @@ type AclManager struct {
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain chainOutputRules *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@@ -84,13 +85,14 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var ipset *nftables.Set var ipset *nftables.Set
if ipsetName != "" { if ipsetName != "" {
@@ -102,7 +104,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -119,32 +121,23 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
} }
if r.nftSet == nil { if r.nftSet == nil {
if err := m.rConn.DelRule(r.nftRule); err != nil { err := m.rConn.DelRule(r.nftRule)
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil { delete(m.rules, r.GetRuleID())
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush() return m.rConn.Flush()
} }
ips, ok := m.ipsetStore.ips(r.nftSet.Name) ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok { if !ok {
if err := m.rConn.DelRule(r.nftRule); err != nil { err := m.rConn.DelRule(r.nftRule)
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil { delete(m.rules, r.GetRuleID())
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.ID())
return m.rConn.Flush() return m.rConn.Flush()
} }
if _, ok := ips[r.ip.String()]; ok { if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}}) err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil { if err != nil {
@@ -163,20 +156,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
if err := m.rConn.DelRule(r.nftRule); err != nil { err := m.rConn.DelRule(r.nftRule)
if err != nil {
log.Errorf("failed to delete rule: %v", err) log.Errorf("failed to delete rule: %v", err)
} }
if r.mangleRule != nil { err = m.rConn.Flush()
if err := m.rConn.DelRule(r.mangleRule); err != nil { if err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err return err
} }
delete(m.rules, r.ID()) delete(m.rules, r.GetRuleID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name) m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) { if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
@@ -225,6 +214,38 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expIn, Exprs: expIn,
}) })
expOut := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainOutputRules,
Position: 0,
Exprs: expOut,
})
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@@ -239,32 +260,25 @@ func (m *AclManager) Flush() error {
return err return err
} }
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil { if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err) log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err) if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
} }
return nil return nil
} }
func (m *AclManager) addIOFiltering( func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ip net.IP, ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
nftRule: r.nftRule, r.nftRule,
mangleRule: r.mangleRule, r.nftSet,
nftSet: r.nftSet, r.ruleID,
ruleID: r.ruleID, ip,
ip: ip,
}, nil }, nil
} }
@@ -296,6 +310,9 @@ func (m *AclManager) addIOFiltering(
if !bytes.HasPrefix(anyIP, rawIP) { if !bytes.HasPrefix(anyIP, rawIP) {
// source address position // source address position
addrOffset := uint32(12) addrOffset := uint32(12)
if direction == firewall.RuleDirectionOUT {
addrOffset += 4 // is ipv4 address length
}
expressions = append(expressions, expressions = append(expressions,
&expr.Payload{ &expr.Payload{
@@ -325,100 +342,73 @@ func (m *AclManager) addIOFiltering(
} }
} }
expressions = append(expressions, applyPort(sPort, true)...) if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions, applyPort(dPort, false)...) expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
mainExpressions := slices.Clone(expressions) if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
switch action { switch action {
case firewall.ActionAccept: case firewall.ActionAccept:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop: case firewall.ActionDrop:
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop}) expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(ruleId) userData := []byte(strings.Join([]string{ruleId, comment}, " "))
chain := m.chainInputRules var chain *nftables.Chain
if direction == firewall.RuleDirectionIN {
chain = m.chainInputRules
} else {
chain = m.chainOutputRules
}
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Exprs: mainExpressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{ rule := &Rule{
nftRule: nftRule, nftRule: nftRule,
mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset,
nftSet: ipset, ruleID: ruleId,
ruleID: ruleId, ip: ip,
ip: ip,
} }
m.rules[ruleId] = rule m.rules[ruleId] = rule
if ipset != nil { if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name) m.ipsetStore.AddReferenceToIpset(ipset.Name)
} }
return rule, nil return rule, nil
} }
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
return m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
@@ -429,6 +419,15 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
m.chainInputRules = chain m.chainInputRules = chain
// chainNameOutputRules
chain = m.createChain(chainNameOutputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
return err
}
m.chainOutputRules = chain
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
@@ -462,7 +461,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP. // netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{ preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting, Name: chainNamePrerouting,
Table: m.workTable, Table: m.workTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
@@ -470,6 +469,8 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
}) })
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter) m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil { if err := m.rConn.Flush(); err != nil {
@@ -479,6 +480,43 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
return nil return nil
} }
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) { func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{ m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
@@ -494,7 +532,8 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
}, },
}, },
}) })
@@ -641,7 +680,6 @@ func (m *AclManager) flushWithBackoff() (err error) {
for i := 0; ; i++ { for i := 0; ; i++ {
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") { if !strings.Contains(err.Error(), "busy") {
return return
} }
@@ -658,7 +696,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
return return
} }
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error { func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
if m.workTable == nil || chain == nil { if m.workTable == nil || chain == nil {
return nil return nil
} }
@@ -675,19 +713,22 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
split := bytes.Split(rule.UserData, []byte(" ")) split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])] r, ok := m.rules[string(split[0])]
if ok { if ok {
if mangle { *r.nftRule = *rule
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
} }
} }
return nil return nil
} }
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { func generatePeerRuleId(
rulesetID := ":" ip net.IP,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipset *nftables.Set,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
if sPort != nil { if sPort != nil {
rulesetID += sPort.String() rulesetID += sPort.String()
} }
@@ -703,6 +744,12 @@ func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, a
return "set:" + ipset.Name + rulesetID return "set:" + ipset.Name + rulesetID
} }
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte { func ifname(n string) []byte {
b := make([]byte, 16) b := make([]byte, 16)
copy(b, n+"\x00") copy(b, n+"\x00")

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() wgaddr.Address Address() iface.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// We only need to record minimal interface state for potential recreation. // We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains // Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy // a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Close() without needing to store specific rules. // cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
@@ -113,13 +113,14 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -129,18 +130,10 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName) return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -148,7 +141,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -243,7 +236,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -319,19 +312,6 @@ func (m *Manager) cleanupNetbirdTables() error {
return nil return nil
} }
// SetLogLevel sets the log level for the firewall manager
func (m *Manager) SetLogLevel(log.Level) {
// not supported
}
func (m *Manager) EnableRouting() error {
return nil
}
func (m *Manager) DisableRouting() error {
return nil
}
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
// //
// Method also get all rules after flush and refreshes handle values in the rulesets // Method also get all rules after flush and refreshes handle values in the rulesets
@@ -343,22 +323,6 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush() return m.aclManager.Flush()
} }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() iface.WGAddress {
return wgaddr.Address{ return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() wgaddr.Address AddressFunc func() iface.WGAddress
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() wgaddr.Address { func (i *iFaceMock) Address() iface.WGAddress {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Close(nil) err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -74,7 +74,16 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -107,7 +116,7 @@ func TestNftablesManager(t *testing.T) {
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
}, },
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
@@ -162,7 +171,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion") require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Close(nil) err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -171,8 +180,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() iface.WGAddress {
return wgaddr.Address{ return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -191,7 +200,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Close(nil); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -200,8 +209,12 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -274,7 +287,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
t.Cleanup(func() { t.Cleanup(func() {
err := manager.Close(nil) err := manager.Reset(nil)
require.NoError(t, err, "failed to reset manager state") require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset // Verify iptables output after reset
@@ -283,16 +296,24 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"), netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
&fw.Port{Values: []uint16{443}}, &fw.Port{Values: []int{443}},
fw.ActionAccept, fw.ActionAccept,
) )
require.NoError(t, err, "failed to add route filtering rule") require.NoError(t, err, "failed to add route filtering rule")
@@ -308,18 +329,3 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr = runIptablesSave(t) stdout, stderr = runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
} }
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
t.Helper()
require.Equal(t, len(got), len(want), "expression count mismatch")
for i := range got {
if _, isCounter := got[i].(*expr.Counter); isCounter {
_, wantIsCounter := want[i].(*expr.Counter)
require.True(t, wantIsCounter, "expected Counter at index %d", i)
continue
}
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
}
}

View File

@@ -14,31 +14,23 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
tableNat = "nat" chainNameRoutingFw = "netbird-rt-fwd"
chainNameNatPrerouting = "PREROUTING" chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingFw = "netbird-rt-fwd" chainNameForward = "FORWARD"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
) )
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
@@ -57,18 +49,16 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
} }
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{ r := &router{
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -108,52 +98,7 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller // clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear() r.ipsetCounter.Clear()
var merr *multierror.Error return r.removeAcceptForwardRules()
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -188,22 +133,14 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager // Chain is created by acl manager
// TODO: move creation to a common place // TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{ r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting, Name: chainNamePrerouting,
Table: r.workTable, Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
} }
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
@@ -228,7 +165,6 @@ func (r *router) createContainers() error {
// AddRouteFiltering appends a nftables rule to the routing chain // AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -237,7 +173,7 @@ func (r *router) AddRouteFiltering(
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }
@@ -297,13 +233,7 @@ func (r *router) AddRouteFiltering(
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
} }
// Insert DROP rules at the beginning, append ACCEPT rules at the end rule = r.conn.AddRule(rule)
if action == firewall.ActionDrop {
// TODO: Insert after the established rule
rule = r.conn.InsertRule(rule)
} else {
rule = r.conn.AddRule(rule)
}
log.Tracef("Adding route rule %s", spew.Sdump(rule)) log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
@@ -345,7 +275,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
ruleKey := rule.ID() ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey] nftRule, exists := r.rules[ruleKey]
if !exists { if !exists {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@@ -474,10 +404,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -904,10 +830,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -968,269 +890,6 @@ func (r *router) refreshRulesMap() error {
return nil return nil
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR // generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32 var offset uint32
@@ -1294,11 +953,15 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port.IsRange && len(port.Values) == 2 { if port.IsRange && len(port.Values) == 2 {
// Handle port range // Handle port range
exprs = append(exprs, exprs = append(exprs,
&expr.Range{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpGte,
Register: 1, Register: 1,
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]), Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]), },
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
}, },
) )
} else { } else {
@@ -1317,7 +980,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
exprs = append(exprs, &expr.Cmp{ exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(p), Data: binaryutil.BigEndian.PutUint16(uint16(p)),
}) })
} }
} }

View File

@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present // need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Reset(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Reset(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []uint16{80}}, dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,
@@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}, },
destination: netip.MustParsePrefix("10.0.0.0/8"), destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true}, sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop, action: firewall.ActionDrop,
@@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"), destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}}, sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil, dPort: nil,
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
@@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"), destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP, proto: firewall.ProtocolUDP,
sPort: nil, sPort: nil,
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true}, dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN, direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop, action: firewall.ActionDrop,
expectSet: false, expectSet: false,
@@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"), destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP, proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true}, sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}}, dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT, direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept, action: firewall.ActionAccept,
expectSet: false, expectSet: false,
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.ID()] rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:") t.Log("Internal rule expressions:")
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule var nftRule *nftables.Rule
for _, rule := range rules { for _, rule := range rules {
if string(rule.UserData) == ruleKey.ID() { if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule nftRule = rule
break break
} }
@@ -595,20 +595,16 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true payloadFound = true
} }
case *expr.Range: case *expr.Cmp:
if port.IsRange && len(port.Values) == 2 { if port.IsRange {
fromPort := binary.BigEndian.Uint16(ex.FromData) if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
portMatchFound = true portMatchFound = true
} }
} } else {
case *expr.Cmp:
if !port.IsRange {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data) portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values { for _, p := range port.Values {
if p == portValue { if uint16(p) == portValue {
portMatchFound = true portMatchFound = true
break break
} }

View File

@@ -8,14 +8,13 @@ import (
// Rule to handle management of rules // Rule to handle management of rules
type Rule struct { type Rule struct {
nftRule *nftables.Rule nftRule *nftables.Rule
mangleRule *nftables.Rule nftSet *nftables.Set
nftSet *nftables.Set ruleID string
ruleID string ip net.IP
ip net.IP
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) ID() string { func (r *Rule) GetRuleID() string {
return r.ruleID return r.ruleID
} }

View File

@@ -3,20 +3,21 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() wgaddr.Address { func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress return i.WGAddress
} }
@@ -38,7 +39,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create nftables manager: %w", err) return fmt.Errorf("create nftables manager: %w", err)
} }
if err := nft.Close(nil); err != nil { if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err) return fmt.Errorf("reset nftables manager: %w", err)
} }

View File

@@ -3,49 +3,35 @@
package uspfilter package uspfilter
import ( import (
"context" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"net/netip"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
} m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
} }
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager) return m.nativeFirewall.Reset(stateManager)
} }
return nil return nil
} }

View File

@@ -1,12 +1,9 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -23,38 +20,26 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Close(*statemanager.Manager) error { func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[netip.Addr]RuleSet) m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[netip.Addr]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.Stop()
}
if m.logger != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := m.logger.Stop(ctx); err != nil {
log.Errorf("failed to shutdown logger: %v", err)
}
} }
if !isWindowsFirewallReachable() { if !isWindowsFirewallReachable() {

View File

@@ -1,16 +0,0 @@
package common
import (
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice
}

View File

@@ -1,27 +1,21 @@
// common.go
package conntrack package conntrack
import ( import (
"fmt" "net"
"net/netip" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
FlowId uuid.UUID SourceIP net.IP
Direction nftypes.Direction DestIP net.IP
SourceIP netip.Addr SourcePort uint16
DestIP netip.Addr DestPort uint16
lastSeen atomic.Int64 lastSeen atomic.Int64 // Unix nano for atomic access
PacketsTx atomic.Uint64 established atomic.Bool
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@@ -31,15 +25,14 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// UpdateCounters safely updates the packet and byte counters // IsEstablished safely checks if connection is established
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) { func (b *BaseConnTrack) IsEstablished() bool {
if direction == nftypes.Egress { return b.established.Load()
b.PacketsTx.Add(1) }
b.BytesTx.Add(uint64(bytes))
} else { // SetEstablished safely sets the established state
b.PacketsRx.Add(1) func (b *BaseConnTrack) SetEstablished(state bool) {
b.BytesRx.Add(uint64(bytes)) b.established.Store(state)
}
} }
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
@@ -53,14 +46,92 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout return time.Since(lastSeen) > timeout
} }
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection // ConnKey uniquely identifies a connection
type ConnKey struct { type ConnKey struct {
SrcIP netip.Addr SrcIP IPAddr
DstIP netip.Addr DstIP IPAddr
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
} }
func (c ConnKey) String() string { // makeConnKey creates a connection key
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
} }

View File

@@ -1,67 +1,114 @@
package conntrack package conntrack
import ( import (
"context" "net"
"net/netip"
"testing" "testing"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) func BenchmarkIPOperations(b *testing.B) {
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger() b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]netip.Addr, 100) srcIPs := make([]net.IP, 100)
dstIPs := make([]netip.Addr, 100) dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)}) srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
} }
} }
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]netip.Addr, 100) srcIPs := make([]net.IP, 100)
dstIPs := make([]netip.Addr, 100) dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)}) srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)}) dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
} }
} }
}) })

View File

@@ -1,17 +1,11 @@
package conntrack package conntrack
import ( import (
"context" "net"
"fmt"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -23,223 +17,154 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct { type ICMPConnKey struct {
SrcIP netip.Addr // Supports both IPv4 and IPv6
DstIP netip.Addr SrcIP [16]byte
ID uint16 DstIP [16]byte
} Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct { type ICMPConnTrack struct {
BaseConnTrack BaseConnTrack
ICMPType uint8 Sequence uint16
ICMPCode uint8 ID uint16
} }
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
type ICMPTracker struct { type ICMPTracker struct {
logger *nblog.Logger
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
flowLogger nftypes.FlowLogger done chan struct{}
ipPool *PreallocatedIPs
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { func NewICMPTracker(timeout time.Duration) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel, done: make(chan struct{}),
flowLogger: flowLogger, ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine()
return tracker return tracker
} }
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) { // TrackOutbound records an outbound ICMP Echo Request
key := ICMPConnKey{ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
SrcIP: srcIP, key := makeICMPKey(srcIP, dstIP, id, seq)
DstIP: dstIP, now := time.Now().UnixNano()
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
}
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) conn.lastSeen.Store(now)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool { func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false return false
} }
key := ICMPConnKey{ key := makeICMPKey(dstIP, srcIP, id, seq)
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) { if !exists {
return false return false
} }
conn.UpdateLastSeen() if conn.timeoutExceeded(t.timeout) {
conn.UpdateCounters(nftypes.Ingress, size) return false
}
return true return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
} }
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) { func (t *ICMPTracker) cleanupRoutine() {
defer t.tickerCancel()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-ctx.Done(): case <-t.done:
return return
} }
} }
} }
func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) cleanup() {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.tickerCancel() t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) { // makeICMPKey creates an ICMP connection key
t.flowLogger.StoreEvent(nftypes.EventFields{ func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
FlowID: conn.FlowId, return ICMPConnKey{
Type: typ, SrcIP: MakeIPAddr(srcIP),
RuleID: ruleID, DstIP: MakeIPAddr(dstIP),
Direction: conn.Direction, ID: id,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 Sequence: seq,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
} }
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
} }

View File

@@ -1,39 +1,39 @@
package conntrack package conntrack
import ( import (
"net/netip" "net"
"testing" "testing"
) )
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
} }
}) })
} }

View File

@@ -3,16 +3,9 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"context" "net"
"net/netip"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -43,35 +36,6 @@ const (
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
type TCPState int type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const ( const (
TCPStateNew TCPState = iota TCPStateNew TCPState = iota
TCPStateSynSent TCPStateSynSent
@@ -86,147 +50,90 @@ const (
TCPStateClosed TCPStateClosed
) )
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16 State TCPState
DestPort uint16
State TCPState
established atomic.Bool
tombstone atomic.Bool
sync.RWMutex sync.RWMutex
} }
// IsEstablished safely checks if connection is established
func (t *TCPConnTrack) IsEstablished() bool {
return t.established.Load()
}
// SetEstablished safely sets the established state
func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state)
}
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc done chan struct{}
timeout time.Duration timeout time.Duration
flowLogger nftypes.FlowLogger ipPool *PreallocatedIPs
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker { func NewTCPTracker(timeout time.Duration) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel, done: make(chan struct{}),
timeout: timeout, timeout: timeout,
flowLogger: flowLogger, ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine()
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { // TrackOutbound processes an outbound TCP packet and updates connection state
key := ConnKey{ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
SrcIP: srcIP, // Create key before lock
DstIP: dstIP, key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
SrcPort: srcPort, now := time.Now().UnixNano()
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock() t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID) // Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.lastSeen.Store(now)
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool { func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
key := ConnKey{ if !isValidFlagCombination(flags) {
SrcIP: dstIP, return false
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
} }
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
@@ -235,26 +142,22 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return false return false
} }
// Handle RST flag specially - it always causes transition to closed // Handle RST packets
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
if conn.IsTombstone() { conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true return true
} }
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size) return false
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
} }
conn.Lock() conn.Lock()
t.updateState(key, conn, flags, false) t.updateState(conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished() isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags) isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock() conn.Unlock()
@@ -263,17 +166,15 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
} }
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen() // Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
return
}
state := conn.State switch conn.State {
defer func() {
if state != conn.State {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
switch state {
case TCPStateNew: case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent conn.State = TCPStateSynSent
@@ -282,11 +183,11 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
case TCPStateSynSent: case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound { if isOutbound {
conn.State = TCPStateEstablished conn.State = TCPStateSynReceived
conn.SetEstablished(true)
} else { } else {
// Simultaneous open // Simultaneous open
conn.State = TCPStateSynReceived conn.State = TCPStateEstablished
conn.SetEstablished(true)
} }
} }
@@ -304,41 +205,28 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateCloseWait conn.State = TCPStateCloseWait
} }
conn.SetEstablished(false) conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
switch { switch {
case flags&TCPFin != 0 && flags&TCPAck != 0: case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing conn.State = TCPStateClosing
case flags&TCPFin != 0: case flags&TCPFin != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPAck != 0: case flags&TCPAck != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@@ -349,12 +237,11 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetTombstone()
// Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
} }
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
} }
} }
@@ -399,14 +286,12 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false return false
} }
func (t *TCPTracker) cleanupRoutine(ctx context.Context) { func (t *TCPTracker) cleanupRoutine() {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-ctx.Done(): case <-t.done:
return return
} }
} }
@@ -417,12 +302,6 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration var timeout time.Duration
switch { switch {
case conn.State == TCPStateTimeWait: case conn.State == TCPStateTimeWait:
@@ -433,26 +312,27 @@ func (t *TCPTracker) cleanup() {
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
if conn.timeoutExceeded(timeout) { lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
// Return IPs to pool // Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.tickerCancel() t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
@@ -470,21 +350,3 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net/netip" "net"
"testing" "testing"
"time" "time"
@@ -9,11 +9,11 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2") dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0) isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc) require.Equal(t, !tt.wantDrop, isValid, tt.desc)
}) })
} }
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper() t.Helper()
// Send initial SYN // Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK // Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK // Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer // Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake") require.True(t, valid, "Data should be allowed after handshake")
}, },
}, },
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN // Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN // Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed") require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side // Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed") require.True(t, valid, "FIN should be allowed")
// Send final ACK // Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}, },
}, },
{ {
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets // Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK // Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed") require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK // Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed") require.True(t, valid, "Final ACKs should be allowed")
}, },
}, },
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker = NewTCPTracker(DefaultTCPTimeout)
tt.test(t) tt.test(t)
}) })
} }
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1") srcIP := net.ParseIP("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2") dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established", name: "RST in established",
setupState: func() { setupState: func() {
// Establish connection first // Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}, },
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
}, },
wantValid: true, wantValid: true,
desc: "Should accept RST for established connection", desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection", name: "RST without connection",
setupState: func() {}, setupState: func() {},
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
}, },
wantValid: false, wantValid: false,
desc: "Should reject RST without connection", desc: "Should reject RST without connection",
@@ -208,12 +208,7 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST() tt.sendRST()
// Verify connection state is as expected // Verify connection state is as expected
key := ConnKey{ key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
@@ -225,63 +220,63 @@ func TestRSTHandling(t *testing.T) {
} }
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
} }
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
} }
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
if i%2 == 0 { if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else { } else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
} }
i++ i++
} }
@@ -292,14 +287,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
} }
// Wait for connections to expire // Wait for connections to expire

View File

@@ -1,15 +1,9 @@
package conntrack package conntrack
import ( import (
"context" "net"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -22,135 +16,96 @@ const (
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
type UDPConnTrack struct { type UDPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
} }
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
type UDPTracker struct { type UDPTracker struct {
logger *nblog.Logger
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
flowLogger nftypes.FlowLogger done chan struct{}
ipPool *PreallocatedIPs
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker { func NewUDPTracker(timeout time.Duration) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel, done: make(chan struct{}),
flowLogger: flowLogger, ipPool: NewPreallocatedIPs(),
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine()
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
// if (inverted direction) conn is not tracked, track this direction now := time.Now().UnixNano()
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
}
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key) conn.lastSeen.Store(now)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool { func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := ConnKey{ key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) { if !exists {
return false return false
} }
conn.UpdateLastSeen() if conn.timeoutExceeded(t.timeout) {
conn.UpdateCounters(nftypes.Ingress, size) return false
}
return true return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) { func (t *UDPTracker) cleanupRoutine() {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-ctx.Done(): case <-t.done:
return return
} }
} }
@@ -162,58 +117,42 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() { func (t *UDPTracker) Close() {
t.tickerCancel() t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// GetConnection safely retrieves a connection state // GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) { func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock() t.mutex.RLock()
defer t.mutex.RUnlock() defer t.mutex.RUnlock()
key := ConnKey{ key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key] conn, exists := t.connections[key]
return conn, exists if !exists {
return nil, false
}
return conn, true
} }
// Timeout returns the configured timeout duration for the tracker // Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration { func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,8 +1,7 @@
package conntrack package conntrack
import ( import (
"context" "net"
"net/netip"
"testing" "testing"
"time" "time"
@@ -30,59 +29,55 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger, flowLogger) tracker := NewUDPTracker(tt.timeout)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker) assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.tickerCancel) assert.NotNil(t, tracker.done)
}) })
} }
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3") dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked // Verify connection was tracked
key := ConnKey{ key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key] conn, exists := tracker.connections[key]
require.True(t, exists) require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0) assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Compare(dstIP) == 0) assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort) assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger) tracker := NewUDPTracker(1 * time.Second)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.2") srcIP := net.ParseIP("192.168.1.2")
dstIP := netip.MustParseAddr("192.168.1.3") dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
// Track outbound connection // Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct { tests := []struct {
name string name string
srcIP netip.Addr srcIP net.IP
dstIP netip.Addr dstIP net.IP
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
sleep time.Duration sleep time.Duration
@@ -99,7 +94,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
}, },
{ {
name: "invalid source IP", name: "invalid source IP",
srcIP: netip.MustParseAddr("192.168.1.4"), srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP, dstIP: srcIP,
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
@@ -109,7 +104,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{ {
name: "invalid destination IP", name: "invalid destination IP",
srcIP: dstIP, srcIP: dstIP,
dstIP: netip.MustParseAddr("192.168.1.4"), dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
sleep: 0, sleep: 0,
@@ -149,7 +144,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 { if tt.sleep > 0 {
time.Sleep(tt.sleep) time.Sleep(tt.sleep)
} }
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0) got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
@@ -160,45 +155,41 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval // Create tracker with custom cleanup interval
tracker := &UDPTracker{ tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
tickerCancel: tickerCancel, done: make(chan struct{}),
logger: logger, ipPool: NewPreallocatedIPs(),
flowLogger: flowLogger,
} }
// Start cleanup routine // Start cleanup routine
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine()
// Add some connections // Add some connections
connections := []struct { connections := []struct {
srcIP netip.Addr srcIP net.IP
dstIP netip.Addr dstIP net.IP
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
}{ }{
{ {
srcIP: netip.MustParseAddr("192.168.1.2"), srcIP: net.ParseIP("192.168.1.2"),
dstIP: netip.MustParseAddr("192.168.1.3"), dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345, srcPort: 12345,
dstPort: 53, dstPort: 53,
}, },
{ {
srcIP: netip.MustParseAddr("192.168.1.4"), srcIP: net.ParseIP("192.168.1.4"),
dstIP: netip.MustParseAddr("192.168.1.5"), dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346, srcPort: 12346,
dstPort: 53, dstPort: 53,
}, },
} }
for _, conn := range connections { for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0) tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
} }
// Verify initial connections // Verify initial connections
@@ -220,33 +211,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close() defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1") srcIP := net.ParseIP("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2") dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
} }
}) })
} }

View File

@@ -1,90 +0,0 @@
package forwarder
import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu uint32
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *endpoint) MTU() uint32 {
return e.mtu
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
}
func (e *endpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
var written int
for _, pkt := range pkts.AsSlice() {
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
data := stack.PayloadSince(pkt.NetworkHeader())
if data == nil {
continue
}
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error("CreateOutboundPacket: %v", err)
continue
}
written++
}
return written, nil
}
func (e *endpoint) Wait() {
// not required
}
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true
}
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -1,169 +0,0 @@
package forwarder
import (
"context"
"fmt"
"net"
"runtime"
log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
defaultReceiveWindow = 32768
defaultMaxInFlight = 1024
iosReceiveWindow = 16384
iosMaxInFlight = 256
)
type Forwarder struct {
logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack
endpoint *endpoint
udpForwarder *udpForwarder
ctx context.Context
cancel context.CancelFunc
ip net.IP
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
},
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
device: iface.GetWGDevice(),
mtu: uint32(mtu),
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
PrefixLen: ones,
},
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
)
if err != nil {
return nil, fmt.Errorf("creating default subnet: %w", err)
}
if err := s.SetPromiscuousMode(nicID, true); err != nil {
return nil, fmt.Errorf("set promiscuous mode: %s", err)
}
if err := s.SetSpoofing(nicID, true); err != nil {
return nil, fmt.Errorf("set spoofing: %s", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: defaultSubnet,
NIC: nicID,
},
})
ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{
logger: logger,
flowLogger: flowLogger,
stack: s,
endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx,
cancel: cancel,
netstack: netstack,
ip: iface.Address().IP,
}
receiveWindow := defaultReceiveWindow
maxInFlight := defaultMaxInFlight
if runtime.GOOS == "ios" {
receiveWindow = iosReceiveWindow
maxInFlight = iosMaxInFlight
}
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
udpForwarder := udp.NewForwarder(s, f.handleUDP)
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
return f, nil
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(payload),
})
defer pkt.DecRef()
if f.endpoint.dispatcher != nil {
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
}
return nil
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
if f.udpForwarder != nil {
f.udpForwarder.Stop()
}
f.stack.Close()
f.stack.Wait()
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) {
return net.IPv4(127, 0, 0, 1)
}
return addr.AsSlice()
}

View File

@@ -1,127 +0,0 @@
package forwarder
import (
"context"
"net"
"net/netip"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel()
lc := net.ListenConfig{}
// TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now
return false
}
defer func() {
if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err)
}
}()
dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP}
fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return
}
response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response)
if err != nil {
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return
}
ipHdr := make([]byte, header.IPv4MinimumSize)
ip := header.IPv4(ipHdr)
ip.Encode(&header.IPv4Fields{
TotalLength: uint16(header.IPv4MinimumSize + n),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: id.LocalAddress,
DstAddr: id.RemoteAddress,
})
ip.SetChecksum(^ip.CalculateChecksum())
fullPacket := make([]byte, 0, len(ipHdr)+n)
fullPacket = append(fullPacket, ipHdr...)
fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return
}
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
}

View File

@@ -1,132 +0,0 @@
package forwarder
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
// handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
r.Complete(true)
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
errChan := make(chan error, 2)
go func() {
_, err := io.Copy(outConn, inConn)
errChan <- err
}()
go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return
}
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -1,332 +0,0 @@
package forwarder
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const (
udpTimeout = 30 * time.Second
)
type udpPacketConn struct {
conn *gonet.UDPConn
outConn net.Conn
lastSeen atomic.Int64
cancel context.CancelFunc
ep tcpip.Endpoint
flowID uuid.UUID
}
type udpForwarder struct {
sync.RWMutex
logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool
ctx context.Context
cancel context.CancelFunc
}
type idleConn struct {
id stack.TransportEndpointID
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{
New: func() any {
b := make([]byte, mtu)
return &b
},
},
}
go f.cleanup()
return f
}
// Stop stops the UDP forwarder and all active connections
func (f *udpForwarder) Stop() {
f.cancel()
f.Lock()
defer f.Unlock()
for id, conn := range f.conns {
conn.cancel()
if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
}
if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
conn.ep.Close()
delete(f.conns, id)
}
}
// cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-f.ctx.Done():
return
case <-ticker.C:
var idleConns []idleConn
f.RLock()
for id, conn := range f.conns {
if conn.getIdleDuration() > udpTimeout {
idleConns = append(idleConns, idleConn{id, conn})
}
}
f.RUnlock()
for _, idle := range idleConns {
idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
}
if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
}
idle.conn.ep.Close()
f.Lock()
delete(f.conns, idle.id)
f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
}
}
}
}
// handleUDP is called by the UDP forwarder for new packets
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if f.ctx.Err() != nil {
f.logger.Trace("forwarder: context done, dropping UDP packet")
return
}
id := r.ID()
f.udpForwarder.RLock()
_, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock()
if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
return
}
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
connCtx, connCancel := context.WithCancel(f.ctx)
pConn := &udpPacketConn{
conn: inConn,
outConn: outConn,
cancel: connCancel,
ep: ep,
flowID: flowID,
}
pConn.updateLastSeen()
f.udpForwarder.Lock()
// Double-check no connection was created while we were setting up
if _, exists := f.udpForwarder.conns[id]; exists {
f.udpForwarder.Unlock()
pConn.cancel()
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
return
}
f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep)
}
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
pConn.cancel()
if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
}
if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
}
ep.Close()
f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}()
errChan := make(chan error, 2)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
}
}
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano())
}
func (c *udpPacketConn) getIdleDuration() time.Duration {
lastSeen := time.Unix(0, c.lastSeen.Load())
return time.Since(lastSeen)
}
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp)
buffer := *bufp
for {
if ctx.Err() != nil {
return ctx.Err()
}
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
n, err := src.Read(buffer)
if err != nil {
if isTimeout(err) {
continue
}
return fmt.Errorf("read from %s: %w", direction, err)
}
_, err = dst.Write(buffer[:n])
if err != nil {
return fmt.Errorf("write to %s: %w", direction, err)
}
c.updateLastSeen()
}
}
func isClosedError(err error) bool {
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
}
func isTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}

View File

@@ -1,131 +0,0 @@
package uspfilter
import (
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
)
type localIPManager struct {
mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
ipv4Bitmap [1 << 16]uint32
}
func newLocalIPManager() *localIPManager {
return &localIPManager{}
}
func (m *localIPManager) setBitmapBit(ip net.IP) {
ipv4 := ip.To4()
if ipv4 == nil {
return
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1])
low := (uint16(ip[2]) << 8) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
}
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
}
return nil
}
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs()
if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
return
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err)
}
}
}
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v", r)
}
}()
var newIPv4Bitmap [1 << 16]uint32
ipv4Set := make(map[string]struct{})
var ipv4Addresses []string
// 127.0.0.0/8
high := uint16(127) << 8
for i := uint16(0); i < 256; i++ {
newIPv4Bitmap[high|i] = 0xffffffff
}
if iface != nil {
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
return err
}
}
interfaces, err := net.Interfaces()
if err != nil {
log.Warnf("failed to get interfaces: %v", err)
} else {
for _, intf := range interfaces {
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
}
}
m.mu.Lock()
m.ipv4Bitmap = newIPv4Bitmap
m.mu.Unlock()
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
return nil
}
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock()
defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice())
}
return false
}

View File

@@ -1,271 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func TestLocalIPManager(t *testing.T) {
tests := []struct {
name string
setupAddr wgaddr.Address
testIP netip.Addr
expected bool
}{
{
name: "Localhost range",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.2"),
expected: true,
},
{
name: "Localhost standard address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.0.0.1"),
expected: true,
},
{
name: "Localhost range edge",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("127.255.255.255"),
expected: true,
},
{
name: "Local IP matches",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.1"),
expected: true,
},
{
name: "Local IP doesn't match",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "IPv6 address",
setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
},
testIP: netip.MustParseAddr("fe80::1"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{
AddressFunc: func() wgaddr.Address {
return tt.setupAddr
},
}
err := manager.UpdateLocalIPs(mock)
require.NoError(t, err)
result := manager.IsLocalIP(tt.testIP)
require.Equal(t, tt.expected, result)
})
}
}
func TestLocalIPManager_AllInterfaces(t *testing.T) {
manager := newLocalIPManager()
mock := &IFaceMock{}
// Get actual local interfaces
interfaces, err := net.Interfaces()
require.NoError(t, err)
var tests []struct {
ip string
expected bool
}
// Add all local interface IPs to test cases
for _, iface := range interfaces {
addrs, err := iface.Addrs()
require.NoError(t, err)
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
if ip4 := ip.To4(); ip4 != nil {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip4.String(),
expected: true,
})
}
}
}
// Add some external IPs as negative test cases
externalIPs := []string{
"8.8.8.8",
"1.1.1.1",
"208.67.222.222",
}
for _, ip := range externalIPs {
tests = append(tests, struct {
ip string
expected bool
}{
ip: ip,
expected: false,
})
}
require.NotEmpty(t, tests, "No test cases generated")
err = manager.UpdateLocalIPs(mock)
require.NoError(t, err)
t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
})
}
}
// MapImplementation is a version using map[string]struct{}
type MapImplementation struct {
localIPs map[string]struct{}
}
func BenchmarkIPChecks(b *testing.B) {
interfaces := make([]net.IP, 16)
for i := range interfaces {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
}
// Setup bitmap version
bitmapManager := &localIPManager{
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip)
}
// Setup map version
mapManager := &MapImplementation{
localIPs: make(map[string]struct{}),
}
for _, ip := range interfaces[:8] {
mapManager.localIPs[ip.String()] = struct{}{}
}
b.Run("Bitmap_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Bitmap_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
bitmapManager.checkBitmapBit(ip)
}
})
b.Run("Map_Hit", func(b *testing.B) {
ip := interfaces[4]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
b.Run("Map_Miss", func(b *testing.B) {
ip := interfaces[12]
b.ResetTimer()
for i := 0; i < b.N; i++ {
// nolint:gosimple
_, _ = mapManager.localIPs[ip.String()]
}
})
}
func BenchmarkWGPosition(b *testing.B) {
wgIP := net.ParseIP("10.10.0.1")
// Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
bm.setBitmapBit(wgIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
// Fill with other IPs first
for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
}
bm.setBitmapBit(wgIP) // Add WG IP last
b.ResetTimer()
for i := 0; i < b.N; i++ {
bm.checkBitmapBit(wgIP)
}
})
}

View File

@@ -1,252 +0,0 @@
// Package log provides a high-performance, non-blocking logger for userspace networking
package log
import (
"context"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
)
const (
maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2
defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
)
type Level uint32
const (
LevelPanic Level = iota
LevelFatal
LevelError
LevelWarn
LevelInfo
LevelDebug
LevelTrace
)
var levelStrings = map[Level]string{
LevelPanic: "PANC",
LevelFatal: "FATL",
LevelError: "ERRO",
LevelWarn: "WARN",
LevelInfo: "INFO",
LevelDebug: "DEBG",
LevelTrace: "TRAC",
}
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{
output: logrusLogger.Out,
msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}),
bufPool: sync.Pool{
New: func() any {
b := make([]byte, 0, maxMessageSize)
return &b
},
},
}
logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)]
log.Debugf("New uspfilter logger created with loglevel %v", level)
l.wg.Add(1)
go l.worker()
return l
}
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
}
func (l *Logger) log(level Level, format string, args ...any) {
select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
default:
}
}
// Error logs a message at error level
func (l *Logger) Error(format string, args ...any) {
if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...)
}
}
// Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...)
}
}
// Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...)
}
}
// Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...)
}
}
// Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...)
}
}
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() {
defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop()
buffer := make([]byte, 0, maxBatchSize)
for {
select {
case <-l.shutdown:
l.handleShutdown(&buffer)
return
case <-ticker.C:
l.flushBuffer(&buffer)
case msg := <-l.msgChannel:
l.processMessage(msg, &buffer)
l.processBatch(&buffer)
}
}
}
// Stop gracefully shuts down the logger
func (l *Logger) Stop(ctx context.Context) error {
done := make(chan struct{})
l.closeOnce.Do(func() {
close(l.shutdown)
})
go func() {
l.wg.Wait()
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}

View File

@@ -1,121 +0,0 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -1,45 +1,30 @@
package uspfilter package uspfilter
import ( import (
"net/netip" "net"
"github.com/google/gopacket" "github.com/google/gopacket"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) )
// PeerRule to handle management of rules // Rule to handle management of rules
type PeerRule struct { type Rule struct {
id string id string
mgmtId []byte ip net.IP
ip netip.Addr
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
sPort *firewall.Port direction firewall.RuleDirection
dPort *firewall.Port sPort uint16
dPort uint16
drop bool drop bool
comment string
udpHook func([]byte) bool udpHook func([]byte) bool
} }
// ID returns the rule id // GetRuleID returns the rule id
func (r *PeerRule) ID() string { func (r *Rule) GetRuleID() string {
return r.id
}
type RouteRule struct {
id string
mgmtId []byte
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
srcPort *firewall.Port
dstPort *firewall.Port
action firewall.Action
}
// ID returns the rule id
func (r *RouteRule) ID() string {
return r.id return r.id
} }

View File

@@ -1,411 +0,0 @@
package uspfilter
import (
"fmt"
"net/netip"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
)
type PacketStage int
const (
StageReceived PacketStage = iota
StageConntrack
StagePeerACL
StageRouting
StageRouteACL
StageForwarding
StageCompleted
)
const msgProcessingCompleted = "Processing completed"
func (s PacketStage) String() string {
return map[PacketStage]string{
StageReceived: "Received",
StageConntrack: "Connection Tracking",
StagePeerACL: "Peer ACL",
StageRouting: "Routing",
StageRouteACL: "Route ACL",
StageForwarding: "Forwarding",
StageCompleted: "Completed",
}[s]
}
type ForwarderAction struct {
Action string
RemoteAddr string
Error error
}
type TraceResult struct {
Timestamp time.Time
Stage PacketStage
Message string
Allowed bool
ForwarderAction *ForwarderAction
}
type PacketTrace struct {
SourceIP netip.Addr
DestinationIP netip.Addr
Protocol string
SourcePort uint16
DestinationPort uint16
Direction fw.RuleDirection
Results []TraceResult
}
type TCPState struct {
SYN bool
ACK bool
FIN bool
RST bool
PSH bool
URG bool
}
type PacketBuilder struct {
SrcIP netip.Addr
DstIP netip.Addr
Protocol fw.Protocol
SrcPort uint16
DstPort uint16
ICMPType uint8
ICMPCode uint8
Direction fw.RuleDirection
PayloadSize int
TCPState *TCPState
}
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
})
}
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
t.Results = append(t.Results, TraceResult{
Timestamp: time.Now(),
Stage: stage,
Message: message,
Allowed: allowed,
ForwarderAction: action,
})
}
func (p *PacketBuilder) Build() ([]byte, error) {
ip := p.buildIPLayer()
pktLayers := []gopacket.SerializableLayer{ip}
transportLayer, err := p.buildTransportLayer(ip)
if err != nil {
return nil, err
}
pktLayers = append(pktLayers, transportLayer...)
if p.PayloadSize > 0 {
payload := make([]byte, p.PayloadSize)
pktLayers = append(pktLayers, gopacket.Payload(payload))
}
return serializePacket(pktLayers)
}
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
return &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP.AsSlice(),
}
}
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
switch p.Protocol {
case "tcp":
return p.buildTCPLayer(ip)
case "udp":
return p.buildUDPLayer(ip)
case "icmp":
return p.buildICMPLayer()
default:
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
}
}
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
tcp := &layers.TCP{
SrcPort: layers.TCPPort(p.SrcPort),
DstPort: layers.TCPPort(p.DstPort),
Window: 65535,
SYN: p.TCPState != nil && p.TCPState.SYN,
ACK: p.TCPState != nil && p.TCPState.ACK,
FIN: p.TCPState != nil && p.TCPState.FIN,
RST: p.TCPState != nil && p.TCPState.RST,
PSH: p.TCPState != nil && p.TCPState.PSH,
URG: p.TCPState != nil && p.TCPState.URG,
}
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
}
return []gopacket.SerializableLayer{tcp}, nil
}
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
udp := &layers.UDP{
SrcPort: layers.UDPPort(p.SrcPort),
DstPort: layers.UDPPort(p.DstPort),
}
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
}
return []gopacket.SerializableLayer{udp}, nil
}
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
icmp := &layers.ICMPv4{
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
}
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
icmp.Id = uint16(1)
icmp.Seq = uint16(1)
}
return []gopacket.SerializableLayer{icmp}, nil
}
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
return nil, fmt.Errorf("serialize packet: %w", err)
}
return buf.Bytes(), nil
}
func getIPProtocolNumber(protocol fw.Protocol) int {
switch protocol {
case fw.ProtocolTCP:
return int(layers.IPProtocolTCP)
case fw.ProtocolUDP:
return int(layers.IPProtocolUDP)
case fw.ProtocolICMP:
return int(layers.IPProtocolICMPv4)
default:
return 0
}
}
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
packetData, err := builder.Build()
if err != nil {
return nil, fmt.Errorf("build packet: %w", err)
}
return m.TracePacket(packetData, builder.Direction), nil
}
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
trace := &PacketTrace{Direction: direction}
// Initial packet decoding
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
return trace
}
// Extract base packet info
srcIP, dstIP := m.extractIPs(d)
trace.SourceIP = srcIP
trace.DestinationIP = dstIP
// Determine protocol and ports
switch d.decoded[1] {
case layers.LayerTypeTCP:
trace.Protocol = "TCP"
trace.SourcePort = uint16(d.tcp.SrcPort)
trace.DestinationPort = uint16(d.tcp.DstPort)
case layers.LayerTypeUDP:
trace.Protocol = "UDP"
trace.SourcePort = uint16(d.udp.SrcPort)
trace.DestinationPort = uint16(d.udp.DstPort)
case layers.LayerTypeICMPv4:
trace.Protocol = "ICMP"
}
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
if direction == fw.RuleDirectionOUT {
return m.traceOutbound(packetData, trace)
}
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
}
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace
}
if m.localipmanager.IsLocalIP(dstIP) {
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
}
if !m.handleRouting(trace) {
return trace
}
if m.nativeRouter.Load() {
return m.handleNativeRouter(trace)
}
return m.handleRouteACLs(trace, d, srcIP, dstIP)
}
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found"
if allowed {
msg = m.buildConntrackStateMessage(d)
trace.AddResult(StageConntrack, msg, true)
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
return true
}
trace.AddResult(StageConntrack, msg, false)
return false
}
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
msg := "Matched existing connection state"
switch d.decoded[1] {
case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp)
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
flags&conntrack.TCPSyn != 0,
flags&conntrack.TCPAck != 0,
flags&conntrack.TCPRst != 0,
flags&conntrack.TCPFin != 0)
case layers.LayerTypeICMPv4:
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
}
return msg
}
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true
}
trace.AddResult(StagePeerACL, msg, true)
// Handle netstack mode
if m.netstack {
switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
}
// In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return true
}
func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false
}
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
return true
}
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
trace.AddResult(StageForwarding, "Forwarding via native router", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return trace
}
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id)
if id == nil {
strId = "<no id>"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
if !allowed {
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
}
trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
}
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
return trace
}
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
fwdAction := &ForwarderAction{
Action: action,
RemoteAddr: remoteAddr,
}
trace.AddResultWithForwarder(StageForwarding,
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
}
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state
dropped := m.processOutgoingHooks(packetData, 0)
if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else {
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
}
return trace
}

View File

@@ -1,440 +0,0 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,9 @@
//go:build uspbench
package uspfilter package uspfilter
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"os" "os"
"strings" "strings"
"testing" "testing"
@@ -93,7 +90,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false, stateful: false,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "") _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -113,15 +111,10 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern // Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0] ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
nil, &fw.Port{Values: []int{1024 + i}},
ip, &fw.Port{Values: []int{80}},
fw.ProtocolTCP, fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return")
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept,
"",
)
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@@ -132,15 +125,8 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true, stateful: true,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
nil, fw.RuleDirectionIN, fw.ActionDrop, "", "default drop")
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@@ -169,9 +155,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -193,13 +179,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound, 0) manager.processOutgoingHooks(outbound)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0) manager.dropFilter(inbound, manager.incomingRules)
} }
}) })
} }
@@ -214,9 +200,9 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -230,7 +216,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound, 0) manager.processOutgoingHooks(outbound)
} }
// Test packet // Test packet
@@ -238,11 +224,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut, 0) manager.processOutgoingHooks(testOut)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, 0) manager.dropFilter(testIn, manager.incomingRules)
} }
}) })
} }
@@ -262,9 +248,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -278,12 +264,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound, 0) manager.processOutgoingHooks(outbound)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0) manager.dropFilter(inbound, manager.incomingRules)
} }
}) })
} }
@@ -461,9 +447,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
// Setup scenario // Setup scenario
@@ -477,25 +463,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound, 0) manager.processOutgoingHooks(outbound)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0) manager.processOutgoingHooks(syn)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0) manager.dropFilter(synack, manager.incomingRules)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0) manager.processOutgoingHooks(ack)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, 0) manager.dropFilter(inbound, manager.incomingRules)
} }
}) })
} }
@@ -588,9 +574,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -601,7 +587,10 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -624,17 +613,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0) manager.processOutgoingHooks(syn)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0) manager.dropFilter(synack, manager.incomingRules)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0) manager.processOutgoingHooks(ack)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@@ -655,9 +644,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx], 0) manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], 0) manager.dropFilter(inPackets[connIdx], manager.incomingRules)
} }
}) })
} }
@@ -676,9 +665,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -689,7 +678,10 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -761,19 +753,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn, 0) manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, 0) manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack, 0) manager.processOutgoingHooks(p.ack)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request, 0) manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, 0) manager.dropFilter(p.response, manager.incomingRules)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient, 0) manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, 0) manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, 0) manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient, 0) manager.processOutgoingHooks(p.ackClient)
} }
}) })
} }
@@ -792,9 +784,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -804,7 +796,10 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -826,15 +821,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn, 0) manager.processOutgoingHooks(syn)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, 0) manager.dropFilter(synack, manager.incomingRules)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack, 0) manager.processOutgoingHooks(ack)
} }
// Pre-generate test packets // Pre-generate test packets
@@ -856,8 +851,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx], 0) manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], 0) manager.dropFilter(inPackets[connIdx], manager.incomingRules)
} }
}) })
}) })
@@ -877,9 +872,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Reset(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -888,7 +883,10 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}) })
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []int{80}},
nil,
fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -950,17 +948,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn, 0) manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, 0) manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack, 0) manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request, 0) manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, 0) manager.dropFilter(p.response, manager.incomingRules)
manager.processOutgoingHooks(p.finClient, 0) manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, 0) manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, 0) manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient, 0) manager.processOutgoingHooks(p.ackClient)
} }
}) })
}) })
@@ -998,65 +996,3 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes() return buf.Bytes()
} }
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
// Add several route rules to simulate real-world scenario
rules := []struct {
sources []netip.Prefix
dest netip.Prefix
proto fw.Protocol
port *fw.Port
}{
{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
port: &fw.Port{Values: []uint16{80, 443}},
},
{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/12"),
netip.MustParsePrefix("10.0.0.0/8"),
},
dest: netip.MustParsePrefix("0.0.0.0/0"),
proto: fw.ProtocolICMP,
},
{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.0.0/16"),
proto: fw.ProtocolUDP,
port: &fw.Port{Values: []uint16{53}},
},
}
for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
if err != nil {
b.Fatal(err)
}
}
// Test cases that exercise different matching scenarios
cases := []struct {
srcIP string
dstIP string
proto fw.Protocol
dstPort uint16
}{
{"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule
{"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule
{"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule
{"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,50 +1,25 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address AddressFunc func() iface.WGAddress
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil
}
return i.GetWGDeviceFunc()
}
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
if i.GetDeviceFunc == nil {
return nil
}
return i.GetDeviceFunc()
} }
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
@@ -54,9 +29,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() wgaddr.Address { func (i *IFaceMock) Address() iface.WGAddress {
if i.AddressFunc == nil { if i.AddressFunc == nil {
return wgaddr.Address{} return iface.WGAddress{}
} }
return i.AddressFunc() return i.AddressFunc()
} }
@@ -66,7 +41,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -86,7 +61,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -94,10 +69,12 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -119,25 +96,48 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
ip := netip.MustParseAddr("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "") rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
ip = net.ParseIP("192.168.1.1")
proto = fw.ProtocolTCP
port = &fw.Port{Values: []int{80}}
direction = fw.RuleDirectionIN
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; !ok { if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip][r.ID()]; ok { if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string name string
in bool in bool
expDir fw.RuleDirection expDir fw.RuleDirection
ip netip.Addr ip net.IP
dPort uint16 dPort uint16
hook func([]byte) bool hook func([]byte) bool
expectedID string expectedID string
@@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook", name: "Test Outgoing UDP Packet Hook",
in: false, in: false,
expDir: fw.RuleDirectionOUT, expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"), ip: net.IPv4(10, 168, 0, 1),
dPort: 8000, dPort: 8000,
hook: func([]byte) bool { return true }, hook: func([]byte) bool { return true },
}, },
@@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook", name: "Test Incoming UDP Packet Hook",
in: true, in: true,
expDir: fw.RuleDirectionIN, expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"), ip: net.IPv6loopback,
dPort: 9000, dPort: 9000,
hook: func([]byte) bool { return false }, hook: func([]byte) bool { return false },
}, },
@@ -189,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule PeerRule var addedRule Rule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip]) != 1 { if len(manager.incomingRules[tt.ip.String()]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
for _, rule := range manager.incomingRules[tt.ip] { for _, rule := range manager.incomingRules[tt.ip.String()] {
addedRule = rule addedRule = rule
} }
} else { } else {
@@ -208,23 +208,27 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip] { for _, rule := range manager.outgoingRules[tt.ip.String()] {
addedRule = rule addedRule = rule
} }
} }
if tt.ip.Compare(addedRule.ip) != 0 { if !tt.ip.Equal(addedRule.ip) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
if tt.dPort != addedRule.dPort.Values[0] { if tt.dPort != addedRule.dPort {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0]) t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
return return
} }
if layers.LayerTypeUDP != addedRule.protoLayer { if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer) t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return return
} }
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil { if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set") t.Errorf("expected udpHook to be set")
return return
@@ -238,7 +242,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -246,16 +250,18 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1") ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
err = m.Close(nil) err = m.Reset(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -269,18 +275,9 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -292,9 +289,11 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
direction := fw.RuleDirectionOUT
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "") _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -328,12 +327,12 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes(), 0) { if m.dropFilter(buf.Bytes(), m.outgoingRules) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
if err = m.Close(nil); err != nil { if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -347,17 +346,17 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface, false, flowLogger) manager, err := Create(iface)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Reset(nil))
}() }()
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc) hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
@@ -393,7 +392,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -401,9 +400,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Reset(nil))
}() }()
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
@@ -423,7 +422,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false hookCalled := false
hookID := manager.AddUDPPacketHook( hookID := manager.AddUDPPacketHook(
false, false,
netip.MustParseAddr("100.10.0.100"), net.ParseIP("100.10.0.100"),
53, 53,
func([]byte) bool { func([]byte) bool {
hookCalled = true hookCalled = true
@@ -458,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes(), 0) result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@@ -468,7 +467,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes(), 0) result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result) require.False(t, result)
} }
@@ -479,12 +478,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Close(nil); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -493,8 +492,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100") ip := net.ParseIP("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []int{1000 + i}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }
@@ -506,7 +509,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) })
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -515,7 +518,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -530,12 +533,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, },
} }
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Reset(nil))
}() }()
// Set up packet parameters // Set up packet parameters
srcIP := netip.MustParseAddr("100.10.0.1") srcIP := net.ParseIP("100.10.0.1")
dstIP := netip.MustParseAddr("100.10.0.100") dstIP := net.ParseIP("100.10.0.100")
srcPort := uint16(51334) srcPort := uint16(51334)
dstPort := uint16(53) dstPort := uint16(53)
@@ -543,8 +546,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{ outboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: srcIP.AsSlice(), SrcIP: srcIP,
DstIP: dstIP.AsSlice(), DstIP: dstIP,
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
outboundUDP := &layers.UDP{ outboundUDP := &layers.UDP{
@@ -569,15 +572,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0) drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match") require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match") require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@@ -585,8 +588,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{ inboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: dstIP.AsSlice(), // Original destination is now source SrcIP: dstIP, // Original destination is now source
DstIP: srcIP.AsSlice(), // Original source is now destination DstIP: srcIP, // Original source is now destination
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
inboundUDP := &layers.UDP{ inboundUDP := &layers.UDP{
@@ -636,7 +639,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), 0) drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@@ -685,7 +688,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0) drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@@ -707,7 +710,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), 0) drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@@ -5,6 +5,7 @@ import (
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
@@ -13,8 +14,6 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -53,10 +52,9 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -66,7 +64,6 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Ad
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
address: address,
} }
rc := receiverCreator{ rc := receiverCreator{
@@ -111,17 +108,35 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil return s.udpMux, nil
} }
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) { func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
b.endpoints[fakeIP] = conn b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock() b.endpointsMu.Unlock()
return fakeUDPAddr, nil
} }
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) { func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
defer b.endpointsMu.Unlock() defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -146,10 +161,9 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -261,6 +275,21 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
} }
} }
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message { func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message) return msgsPool.Get().(*[]ipv6.Message)
} }

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"slices"
"strings" "strings"
"sync" "sync"
@@ -153,7 +152,46 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
} }
mux := &UDPMuxDefault{ var localAddrsForUnspecified []net.Addr
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if params.Net == nil {
var err error
if params.Net, err = stdnet.NewNet(); err != nil {
params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
return &UDPMuxDefault{
addressMap: map[string][]*udpMuxedConn{}, addressMap: map[string][]*udpMuxedConn{},
params: params, params: params,
connsIPv4: make(map[string]*udpMuxedConn), connsIPv4: make(map[string]*udpMuxedConn),
@@ -165,55 +203,8 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
return newBufferHolder(receiveMTU + maxAddrSize) return newBufferHolder(receiveMTU + maxAddrSize)
}, },
}, },
localAddrsForUnspecified: localAddrsForUnspecified,
} }
mux.updateLocalAddresses()
return mux
}
func (m *UDPMuxDefault) updateLocalAddresses() {
var localAddrsForUnspecified []net.Addr
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
} else if ok && addr.IP.IsUnspecified() {
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
// it will break the applications that are already using unspecified UDP connection
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
var networks []ice.NetworkType
switch {
case addr.IP.To16() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
case addr.IP.To4() != nil:
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
default:
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
}
if len(networks) > 0 {
if m.params.Net == nil {
var err error
if m.params.Net, err = stdnet.NewNet(); err != nil {
m.params.Logger.Errorf("failed to get create network: %v", err)
}
}
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
if err == nil {
for _, ip := range ips {
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
}
} else {
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
}
}
}
m.mu.Lock()
m.localAddrsForUnspecified = localAddrsForUnspecified
m.mu.Unlock()
} }
// LocalAddr returns the listening address of this UDPMuxDefault // LocalAddr returns the listening address of this UDPMuxDefault
@@ -223,12 +214,8 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
// GetListenAddresses returns the list of addresses that this mux is listening on // GetListenAddresses returns the list of addresses that this mux is listening on
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr { func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
m.updateLocalAddresses()
m.mu.Lock()
defer m.mu.Unlock()
if len(m.localAddrsForUnspecified) > 0 { if len(m.localAddrsForUnspecified) > 0 {
return slices.Clone(m.localAddrsForUnspecified) return m.localAddrsForUnspecified
} }
return []net.Addr{m.LocalAddr()} return []net.Addr{m.LocalAddr()}
@@ -238,10 +225,7 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
// creates the connection if an existing one can't be found // creates the connection if an existing one can't be found
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) { func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
// don't check addr for mux using unspecified address // don't check addr for mux using unspecified address
m.mu.Lock() if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
lenLocalAddrs := len(m.localAddrsForUnspecified)
m.mu.Unlock()
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
return nil, fmt.Errorf("invalid address %s", addr.String()) return nil, fmt.Errorf("invalid address %s", addr.String())
} }

View File

@@ -17,8 +17,6 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// FilterFn is a function that filters out candidates based on the address. // FilterFn is a function that filters out candidates based on the address.
@@ -43,7 +41,6 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -67,7 +64,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
filterFn: params.FilterFn, filterFn: params.FilterFn,
address: params.WGAddress,
} }
// embed UDPMux // embed UDPMux
@@ -122,7 +118,6 @@ type udpConn struct {
filterFn FilterFn filterFn FilterFn
// TODO: reset cache on route changes // TODO: reset cache on route changes
addrCache sync.Map addrCache sync.Map
address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -164,11 +159,6 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil { if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err) log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else { } else {

View File

@@ -43,7 +43,13 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -52,7 +58,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,

View File

@@ -2,5 +2,5 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Netbird // WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0" const WgInterfaceDefault = "wt0"

View File

@@ -2,5 +2,5 @@
package configurer package configurer
// WgInterfaceDefault is a default interface name of Netbird // WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100" const WgInterfaceDefault = "utun100"

View File

@@ -52,7 +52,13 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
return err
}
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -61,7 +67,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps, AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
@@ -356,7 +362,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
} }
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
return nbnet.NetbirdFwmark return nbnet.NetbirdFwmark
} }
return 0 return 0

View File

@@ -3,23 +3,16 @@
package iface package iface
import ( import (
"golang.zx2c4.com/wireguard/tun/netstack"
wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address WGAddress) error
WgAddress() wgaddr.Address WgAddress() WGAddress
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice
Device() *wgdevice.Device
GetNet() *netstack.Net
} }

View File

@@ -1,29 +1,29 @@
package wgaddr package device
import ( import (
"fmt" "fmt"
"net" "net"
) )
// Address WireGuard parsed address // WGAddress WireGuard parsed address
type Address struct { type WGAddress struct {
IP net.IP IP net.IP
Network *net.IPNet Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return Address{}, err return WGAddress{}, err
} }
return Address{ return WGAddress{
IP: ip, IP: ip,
Network: network, Network: network,
}, nil }, nil
} }
func (addr Address) String() string { func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -9,16 +9,14 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct { type WGTunDevice struct {
address wgaddr.Address address WGAddress
port int port int
key string key string
mtu int mtu int
@@ -32,7 +30,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -65,7 +63,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
t.filteredDevice = newDeviceFilter(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
@@ -94,7 +92,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error { func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement // todo implement
return nil return nil
} }
@@ -124,7 +122,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *WGTunDevice) WgAddress() wgaddr.Address { func (t *WGTunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
@@ -132,10 +130,6 @@ func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
func (t *WGTunDevice) GetNet() *netstack.Net {
return nil
}
func routesToString(routes []string) string { func routesToString(routes []string) string {
return strings.Join(routes, ";") return strings.Join(routes, ";")
} }

View File

@@ -9,16 +9,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address wgaddr.Address address WGAddress
port int port int
key string key string
mtu int mtu int
@@ -30,7 +28,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -86,7 +84,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error { func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -107,7 +105,7 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() wgaddr.Address { func (t *TunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
@@ -119,11 +117,6 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunDevice) Device() *device.Device {
return t.device
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
@@ -145,7 +138,3 @@ func (t *TunDevice) assignAddr() error {
} }
return nil return nil
} }
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -2,7 +2,6 @@ package device
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -11,16 +10,16 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte, size int) bool DropOutgoing(packetData []byte) bool
// DropIncoming filter incoming packets from external sources to host // DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte, size int) bool DropIncoming(packetData []byte) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not. // Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument. // Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
@@ -58,7 +57,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) { if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@@ -82,7 +81,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:], len(buf)) { if !filter.DropIncoming(buf[offset:]) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@@ -10,16 +10,14 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address wgaddr.Address address WGAddress
port int port int
key string key string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -31,7 +29,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -66,7 +64,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
t.filteredDevice = newDeviceFilter(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface") log.Debug("Attaching to interface")
t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
@@ -121,11 +119,11 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() wgaddr.Address { func (t *TunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error { func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement // todo implement
return nil return nil
} }
@@ -133,7 +131,3 @@ func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
func (t *TunDevice) FilteredDevice() *FilteredDevice { func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
func (t *TunDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -9,18 +9,15 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
name string name string
address wgaddr.Address address WGAddress
wgPort int wgPort int
key string key string
mtu int mtu int
@@ -35,7 +32,9 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -100,10 +99,9 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock, UDPConn: rawSock,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -114,7 +112,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil return t.udpMux, nil
} }
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error { func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -147,7 +145,7 @@ func (t *TunKernelDevice) Close() error {
return closErr return closErr
} }
func (t *TunKernelDevice) WgAddress() wgaddr.Address { func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address return t.address
} }
@@ -155,11 +153,6 @@ func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }
// Device returns the wireguard device, not applicable for kernel devices
func (t *TunKernelDevice) Device() *device.Device {
return nil
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil return nil
} }
@@ -168,7 +161,3 @@ func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
func (t *TunKernelDevice) assignAddr() error { func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address) return t.link.assignAddr(t.address)
} }
func (t *TunKernelDevice) GetNet() *netstack.Net {
return nil
}

View File

@@ -8,18 +8,15 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address wgaddr.Address address WGAddress
port int port int
key string key string
mtu int mtu int
@@ -28,14 +25,12 @@ type TunNetstackDevice struct {
device *device.Device device *device.Device
filteredDevice *FilteredDevice filteredDevice *FilteredDevice
nsTun *nbnetstack.NetStackTun nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer configurer WGConfigurer
net *netstack.Net
} }
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -48,19 +43,13 @@ func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key stri
} }
func (t *TunNetstackDevice) Create() (WGConfigurer, error) { func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
// TODO: get from service listener runtime IP tunIface, err := t.nsTun.Create()
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr)
tunIface, net, err := t.nsTun.Create()
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.filteredDevice = newDeviceFilter(tunIface) t.filteredDevice = newDeviceFilter(tunIface)
t.net = net
t.device = device.NewDevice( t.device = device.NewDevice(
t.filteredDevice, t.filteredDevice,
@@ -98,7 +87,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error { func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil return nil
} }
@@ -117,7 +106,7 @@ func (t *TunNetstackDevice) Close() error {
return nil return nil
} }
func (t *TunNetstackDevice) WgAddress() wgaddr.Address { func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address return t.address
} }
@@ -128,12 +117,3 @@ func (t *TunNetstackDevice) DeviceName() string {
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice return t.filteredDevice
} }
// Device returns the wireguard device
func (t *TunNetstackDevice) Device() *device.Device {
return t.device
}
func (t *TunNetstackDevice) GetNet() *netstack.Net {
return t.net
}

Some files were not shown because too many files have changed in this diff Show More