Compare commits

..

1 Commits

Author SHA1 Message Date
bcmmbaga
5b344f9b3f test migration
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-11 22:04:00 +03:00
354 changed files with 9376 additions and 19716 deletions

View File

@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}
@@ -42,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -38,7 +38,7 @@ jobs:
time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/... time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./iface/...
time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/... time go test -timeout 1m -failfast ./signal/...

View File

@@ -16,16 +16,16 @@ jobs:
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -79,6 +79,9 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -95,7 +98,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin - name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/ run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/...
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
@@ -103,7 +106,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/... run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
@@ -121,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker - name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
id: go id: go
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2

View File

@@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd ignore_words_list: erro,clienta,hastable,
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Check for duplicate constants - name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
@@ -49,4 +49,4 @@ jobs:
uses: golangci/golangci-lint-action@v3 uses: golangci/golangci-lint-action@v3
with: with:
version: latest version: latest
args: --timeout=12m args: --timeout=12m

View File

@@ -13,7 +13,6 @@ concurrency:
jobs: jobs:
test-install-script: test-install-script:
strategy: strategy:
fail-fast: false
max-parallel: 2 max-parallel: 2
matrix: matrix:
os: [ubuntu-latest, macos-latest] os: [ubuntu-latest, macos-latest]
@@ -22,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run install script - name: run install script
env: env:

View File

@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@v4 uses: actions/setup-java@v3
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init - name: gomobile init
@@ -62,4 +62,4 @@ jobs:
- name: build iOS netbird lib - name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0

View File

@@ -3,14 +3,15 @@ name: Release
on: on:
push: push:
tags: tags:
- "v*" - 'v*'
branches: branches:
- main - main
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.16" SIGN_PIPE_VER: "v0.0.14"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -20,7 +21,7 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
env: env:
flags: "" flags: ""
steps: steps:
@@ -33,17 +34,20 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout -
uses: actions/checkout@v4 name: Checkout
uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go -
uses: actions/setup-go@v5 name: Set up Go
uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules -
uses: actions/cache@v4 name: Cache Go modules
uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -51,19 +55,24 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go-releaser- ${{ runner.os }}-go-releaser-
- name: Install modules -
name: Install modules
run: go mod tidy run: go mod tidy
- name: check git status -
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU -
name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx -
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v2
- name: Login to Docker hub -
name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@v1 uses: docker/login-action@v1
with: with:
username: ${{ secrets.DOCKER_USER }} username: netbirdio
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -76,32 +85,36 @@ jobs:
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }} args: release --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes -
uses: actions/upload-artifact@v4 name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
- name: upload linux packages -
uses: actions/upload-artifact@v4 name: upload linux packages
uses: actions/upload-artifact@v3
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 3 retention-days: 3
- name: upload windows packages -
uses: actions/upload-artifact@v4 name: upload windows packages
uses: actions/upload-artifact@v3
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 3 retention-days: 3
- name: upload macos packages -
uses: actions/upload-artifact@v4 name: upload macos packages
uses: actions/upload-artifact@v3
with: with:
name: macos-packages name: macos-packages
path: dist/netbird_darwin** path: dist/netbird_darwin**
@@ -120,19 +133,19 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
~/.cache/go-build ~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
@@ -156,14 +169,14 @@ jobs:
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
@@ -174,17 +187,20 @@ jobs:
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout -
uses: actions/checkout@v4 name: Checkout
uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go -
uses: actions/setup-go@v5 name: Set up Go
uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules -
uses: actions/cache@v4 name: Cache Go modules
uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -192,20 +208,24 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin- ${{ runner.os }}-ui-go-releaser-darwin-
- name: Install modules -
name: Install modules
run: go mod tidy run: go mod tidy
- name: check git status -
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Run GoReleaser -
name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes -
uses: actions/upload-artifact@v4 name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
@@ -213,7 +233,7 @@ jobs:
trigger_signer: trigger_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin] needs: [release,release_ui,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger binaries sign pipelines
@@ -223,4 +243,4 @@ jobs:
repo: netbirdio/sign-pipelines repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }' inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -50,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +63,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -219,7 +219,7 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL - name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird - id: netbird
@@ -24,7 +22,7 @@ builds:
goarch: 386 goarch: 386
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -44,19 +42,19 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
- id: netbird-mgmt - id: netbird-mgmt
dir: management dir: management
env: env:
- CGO_ENABLED=1 - CGO_ENABLED=1
- >- - >-
{{- if eq .Runtime.Goos "linux" }} {{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }} {{- end }}
binary: netbird-mgmt binary: netbird-mgmt
goos: goos:
- linux - linux
@@ -66,7 +64,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-signal - id: netbird-signal
dir: signal dir: signal
@@ -80,7 +78,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-relay - id: netbird-relay
dir: relay dir: relay
@@ -94,10 +92,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
universal_binaries:
- id: netbird
archives: archives:
- builds: - builds:
@@ -105,6 +100,7 @@ archives:
- netbird-static - netbird-static
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client. description: Netbird client.
homepage: https://netbird.io/ homepage: https://netbird.io/
@@ -420,9 +416,10 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
brews: brews:
- ids: -
ids:
- default - default
repository: tap:
owner: netbirdio owner: netbirdio
name: homebrew-tap name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
@@ -439,7 +436,7 @@ brews:
uploads: uploads:
- name: debian - name: debian
ids: ids:
- netbird-deb - netbird-deb
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com username: dev@wiretrustee.com

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui - id: netbird-ui
@@ -13,7 +11,7 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows
dir: client/ui dir: client/ui
@@ -28,7 +26,7 @@ builds:
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui - -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
archives: archives:
- id: linux-arch - id: linux-arch
@@ -41,6 +39,7 @@ archives:
- netbird-ui-windows - netbird-ui-windows
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
homepage: https://netbird.io/ homepage: https://netbird.io/
@@ -78,7 +77,7 @@ nfpms:
uploads: uploads:
- name: debian - name: debian
ids: ids:
- netbird-ui-deb - netbird-ui-deb
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com username: dev@wiretrustee.com

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
@@ -19,13 +17,10 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives: archives:
- builds: - builds:
- netbird-ui-darwin - netbird-ui-darwin
@@ -33,4 +28,4 @@ archives:
checksum: checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt" name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog: changelog:
disable: true skip: true

View File

@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser** **Goreleaser**
```shell ```shell
goreleaser build --snapshot --clean goreleaser --snapshot --rm-dist
``` ```
**golangci-lint** **golangci-lint**
```shell ```shell

View File

@@ -49,8 +49,6 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features ### Key features
@@ -64,7 +62,6 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | | | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -8,7 +8,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -16,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/util/net"
) )
@@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile // TunAdapter export internal TunAdapter for mobile
type TunAdapter interface { type TunAdapter interface {
device.TunAdapter iface.TunAdapter
} }
// IFaceDiscover export internal IFaceDiscover for mobile // IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string cfgFile string
tunAdapter device.TunAdapter tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover iFaceDiscover IFaceDiscover
recorder *peer.Status recorder *peer.Status
ctxCancel context.CancelFunc ctxCancel context.CancelFunc

View File

@@ -5,8 +5,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
func TestInitCommands(t *testing.T) { func TestInitCommands(t *testing.T) {

View File

@@ -805,9 +805,6 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
} }
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes { for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route) peer.Routes[i] = a.AnonymizeIPString(route)
} }

View File

@@ -3,6 +3,7 @@ package cmd
import ( import (
"context" "context"
"net" "net"
"path/filepath"
"testing" "testing"
"time" "time"
@@ -33,12 +34,18 @@ func startTestingServices(t *testing.T) string {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testDir := t.TempDir()
config.Datadir = testDir
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
if err != nil {
t.Fatal(err)
}
_, signalLis := startSignal(t) _, signalLis := startSignal(t)
signalAddr := signalLis.Addr().String() signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config, "../testdata/store.sql") _, mgmLis := startManagement(t, config)
mgmAddr := mgmLis.Addr().String() mgmAddr := mgmLis.Addr().String()
return mgmAddr return mgmAddr
} }
@@ -50,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
srv, err := sig.NewServer(context.Background(), otel.Meter("")) srv, err := sig.NewServer(otel.Meter(""))
require.NoError(t, err) require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv) sigProto.RegisterSignalExchangeServer(s, srv)
@@ -63,7 +70,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
@@ -71,7 +78,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -15,11 +15,11 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -8,8 +8,8 @@ import (
) )
func formatError(es []error) string { func formatError(es []error) string {
if len(es) == 1 { if len(es) == 0 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0]) return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
} }
points := make([]string, len(es)) points := make([]string, len(es))

View File

@@ -3,6 +3,7 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
@@ -10,11 +11,10 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@@ -3,7 +3,7 @@
package firewall package firewall
import ( import (
"errors" "context"
"fmt" "fmt"
"os" "os"
@@ -15,7 +15,6 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -33,65 +32,54 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager) var fm firewall.Manager
var errFw error
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface) fm, errFw = nbiptables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface) fm, errFw = nbnftables.Create(context, iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default: default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
} }
if errUsp != nil { if iface.IsUserspaceBind() {
return nil, fmt.Errorf("create userspace firewall: %s", errUsp) var errUsp error
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
} }
if err := fm.AllowNetbird(); err != nil { if errFw != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err) return nil, errFw
} }
return fm, nil return fm, nil
} }

View File

@@ -1,13 +1,11 @@
package firewall package firewall
import ( import "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() iface.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(iface.PacketFilter) error
} }

View File

@@ -11,8 +11,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -21,65 +19,49 @@ const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
) )
type aclEntries map[string][][]string
type entry struct {
spec []string
position int
}
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routeingFwChainName string
entries aclEntries entries map[string][][]string
optionalEntries map[string][]entry ipsetStore *ipsetStore
ipsetStore *ipsetStore
stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName, routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), ipsetStore: newIpsetStore(),
ipsetStore: newIpsetStore(),
} }
if err := ipset.Init(); err != nil { err := ipset.Init()
return nil, fmt.Errorf("init ipset: %w", err) if err != nil {
return nil, fmt.Errorf("failed to init ipset: %w", err)
} }
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil return m, nil
} }
func (m *aclManager) init(stateManager *statemanager.Manager) error { func (m *aclManager) AddFiltering(
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -145,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err return nil, err
} }
@@ -157,18 +139,28 @@ func (m *aclManager) AddPeerFiltering(
chain: chain, chain: chain,
} }
m.updateState() if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil
}
return []firewall.Rule{rule}, nil rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
} }
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset // delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok { if _, ok := ipsetList.ips[r.ip]; ok {
@@ -193,23 +185,60 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil { DELETERULE:
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err) var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
} }
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
m.updateState() if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
return nil }
return err
} }
func (m *aclManager) Reset() error { func (m *aclManager) Reset() error {
if err := m.cleanChains(); err != nil { return m.cleanChains()
return fmt.Errorf("clean chains: %w", err) }
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
} }
m.updateState() specs = append(src, specs...)
return nil ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
} }
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
@@ -264,7 +293,8 @@ func (m *aclManager) cleanChains() error {
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil { if err != nil {
return fmt.Errorf("list chains: %w", err) log.Debugf("failed to list chains: %s", err)
return err
} }
if ok { if ok {
for _, rule := range m.entries["PREROUTING"] { for _, rule := range m.entries["PREROUTING"] {
@@ -273,6 +303,11 @@ func (m *aclManager) cleanChains() error {
log.Errorf("failed to delete rule: %v, %s", rule, err) log.Errorf("failed to delete rule: %v, %s", rule, err)
} }
} }
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
} }
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
@@ -303,98 +338,64 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if chainName == "FORWARD" {
log.Debugf("failed to create input chain jump rule: %s", err) // position 2 because we add it after router's, jump rule
return err if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} }
} }
} }
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil return nil
} }
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
established := getConntrackEstablished() m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD",
} []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
func (m *aclManager) seedInitialOptionalEntries() { m.appendToEntries("PREROUTING",
m.optionalEntries["FORWARD"] = []entry{ []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec) m.entries[chainName] = append(m.entries[chainName], spec)
} }
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
@@ -455,3 +456,18 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName return ipsetName
} }
} }
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -4,17 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Manager of iptables firewall // Manager of iptables firewall
@@ -25,7 +21,7 @@ type Manager struct {
ipv4Client *iptables.IPTables ipv4Client *iptables.IPTables
aclMgr *aclManager aclMgr *aclManager
router *router router *routerManager
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -36,10 +32,10 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("init iptables: %w", err) return nil, fmt.Errorf("iptables is not installed in the system or not supported")
} }
m := &Manager{ m := &Manager{
@@ -47,53 +43,24 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(iptablesClient, wgIface) m.router, err = newRouterManager(context, iptablesClient)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
} }
return m, nil return m, nil
} }
func (m *Manager) Init(stateManager *statemanager.Manager) error { // AddFiltering rule to the firewall
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
// AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -106,86 +73,50 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( // DeleteRule from the firewall by rule definition
sources []netip.Prefix, func (m *Manager) DeleteRule(rule firewall.Rule) error {
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { return m.aclMgr.DeleteRule(rule)
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) return m.router.InsertRoutingRules(pair)
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) return m.router.RemoveRoutingRules(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
var merr *multierror.Error errAcl := m.aclMgr.Reset()
if errAcl != nil {
if err := m.aclMgr.Reset(); err != nil { log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl)
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
if err := m.router.Reset(); err != nil { errMgr := m.router.Reset()
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) if errMgr != nil {
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
} }
return errAcl
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -194,7 +125,7 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
_, err := m.AddPeerFiltering( _, err := m.AddFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -207,7 +138,7 @@ func (m *Manager) AllowNetbird() error {
if err != nil { if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err) return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
} }
_, err = m.AddPeerFiltering( _, err = m.AddFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -222,7 +153,3 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -1,6 +1,7 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@@ -10,24 +11,9 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -54,15 +40,29 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Reset()
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule1 { for _, r := range rule1 {
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -119,10 +119,10 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@@ -154,14 +154,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Reset()
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -171,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering( rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT, ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic", fw.ActionAccept, "default", "accept HTTP traffic",
) )
@@ -190,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []int{443},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range", "default", "accept HTTPS traffic from ports range",
) )
@@ -203,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@@ -212,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -220,7 +219,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -252,13 +251,12 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Reset()
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -271,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -3,534 +3,370 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net/netip"
"strconv"
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
ipv4Nat = "netbird-rt-nat" Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
matchSet = "--match-set"
) )
type routeFilteringRuleParams struct { type routerManager struct {
Sources []netip.Prefix ctx context.Context
Destination netip.Prefix stop context.CancelFunc
Proto firewall.Protocol iptablesClient *iptables.IPTables
SPort *firewall.Port rules map[string][]string
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
} }
type routeRules map[string][]string func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
ctx, cancel := context.WithCancel(parentCtx)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] m := &routerManager{
ctx: ctx,
type router struct { stop: cancel,
iptablesClient *iptables.IPTables
rules routeRules
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
stateManager *statemanager.Manager
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
r := &router{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface,
} }
r.ipsetCounter = refcounter.New( err := m.cleanUpDefaultForwardRules()
func(name string, sources []netip.Prefix) (struct{}, error) { if err != nil {
return struct{}{}, r.createIpSet(name, sources) log.Errorf("failed to cleanup routing rules: %s", err)
}, return nil, err
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
} }
err = m.createContainers()
return r, nil if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return m, err
} }
func (r *router) init(stateManager *statemanager.Manager) error { // InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
r.stateManager = stateManager func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err := r.cleanUpDefaultForwardRules(); err != nil { if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) return err
} }
if err := r.createContainers(); err != nil { err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
return fmt.Errorf("create containers: %w", err) if err != nil {
} return err
r.updateState()
return nil
}
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}
params := routeFilteringRuleParams{
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
}
r.updateState()
return nil
}
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
}
}
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
return nil
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
} }
if !pair.Masquerade { if !pair.Masquerade {
return nil return nil
} }
if err := r.addNatRule(pair); err != nil { err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
return fmt.Errorf("add nat rule: %w", err) if err != nil {
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
r.updateState()
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err return err
} }
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { if err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return err
}
r.rules[ruleKey] = rule
return nil
}
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
} }
return nil return nil
} }
// GetLegacyManagement returns the current legacy management mode // insertRoutingRule inserts an iptables rule
func (r *router) GetLegacyManagement() bool { func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
return r.legacyManagement var err error
}
// SetLegacyManagement sets the route manager to use legacy management mode ruleKey := firewall.GenKey(keyFormat, pair.ID)
func (r *router) SetLegacyManagement(isLegacy bool) { rule := genRuleSpec(jump, pair.Source, pair.Destination)
r.legacyManagement = isLegacy existingRule, found := i.rules[ruleKey]
} if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls if err != nil {
func (r *router) RemoveAllLegacyRouteRules() error { return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
} }
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { delete(i.rules, ruleKey)
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) }
} else {
delete(r.rules, k) err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
}
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
} }
} }
delete(i.rules, ruleKey)
r.updateState() return nil
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) Reset() error { func (i *routerManager) RouteingFwChainName() string {
var merr *multierror.Error return chainRTFWD
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) cleanUpDefaultForwardRules() error { func (i *routerManager) Reset() error {
err := r.cleanJumpRules() err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
}
i.rules = make(map[string][]string)
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
if err != nil { if err != nil {
return err return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} { ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err)
return err
} else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain)
if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
}
}
}
return nil
}
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %w", chain, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT {
return tableNat
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil {
return fmt.Errorf("failed to insert established rule: %v", err)
}
ruleKey := "established-" + chain
r.rules[ruleKey] = establishedRule
return nil
}
func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT}
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil { if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
return err return err
} } else if ok {
r.rules[ipv4Nat] = rule err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
return nil
}
func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat]
if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil { if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
return err
} }
} }
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err
}
}
return nil return nil
} }
func (r *router) addNatRule(pair firewall.RouterPair) error { func (i *routerManager) createContainers() error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) if i.rules[Ipv4Forwarding] != nil {
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("nat rule %s not found", ruleKey)
}
return nil
}
func (r *router) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
intdir := "-i"
lointdir := "-o"
if inverse {
intdir = "-o"
lointdir = "-i"
}
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
}
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
rule = append(rule, "-j", actionToStr(params.Action))
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil return nil
} }
if port.IsRange && len(port.Values) == 2 { errMSGFormat := "failed creating chain %s,error: %v"
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])} err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
} }
if len(port.Values) > 1 { err = i.createChain(tableNat, chainRTNAT)
portList := make([]string, len(port.Values)) if err != nil {
for i, p := range port.Values { return fmt.Errorf(errMSGFormat, chainRTNAT, err)
portList[i] = strconv.Itoa(p)
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
} }
return []string{flag, strconv.Itoa(port.Values[0])} err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *routerManager) addJumpRules() error {
rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
}
i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
}
}
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
}
return ruleType
} }

View File

@@ -3,13 +3,12 @@
package iptables package iptables
import ( import (
"net/netip" "context"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -29,9 +28,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() _ = manager.Reset()
@@ -39,22 +37,26 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules, 2, "should have created rules map") require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: "100.100.100.0/24",
Masquerade: true, Masquerade: true,
} }
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
@@ -63,7 +65,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_AddNatRule(t *testing.T) { func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
@@ -74,9 +76,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
err := manager.Reset() err := manager.Reset()
@@ -85,13 +86,35 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
} }
}() }()
err = manager.AddNatRule(testCase.InputPair) err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "forwarding pair should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created") require.True(t, exists, "nat rule should be created")
@@ -104,8 +127,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
require.False(t, foundNat, "nat rule should not exist in the map") require.False(t, foundNat, "nat rule should not exist in the map")
} }
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -123,7 +146,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
} }
} }
func TestIptablesManager_RemoveNatRule(t *testing.T) { func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
@@ -133,23 +156,34 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() _ = manager.Reset()
}() }()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
@@ -157,14 +191,28 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist") require.False(t, exists, "nat rule should not exist")
_, found := manager.rules[natRuleKey] _, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map") require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
@@ -173,176 +221,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
_, found = manager.rules[inNatRuleKey] _, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map") require.False(t, found, "income nat rule should exist in the manager map")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if !isIptablesSupported() {
t.Skip("iptables not supported on this system")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
}) })
} }
} }

View File

@@ -1,16 +1,14 @@
package iptables package iptables
import "encoding/json"
type ipList struct { type ipList struct {
ips map[string]struct{} ips map[string]struct{}
} }
func newIpList(ip string) *ipList { func newIpList(ip string) ipList {
ips := make(map[string]struct{}) ips := make(map[string]struct{})
ips[ip] = struct{}{} ips[ip] = struct{}{}
return &ipList{ return ipList{
ips: ips, ips: ips,
} }
} }
@@ -19,47 +17,27 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{} s.ips[ip] = struct{}{}
} }
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
return nil
}
type ipsetStore struct { type ipsetStore struct {
ipsets map[string]*ipList ipsets map[string]ipList // ipsetName -> ruleset
} }
func newIpsetStore() *ipsetStore { func newIpsetStore() *ipsetStore {
return &ipsetStore{ return &ipsetStore{
ipsets: make(map[string]*ipList), ipsets: make(map[string]ipList),
} }
} }
func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) { func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) {
r, ok := s.ipsets[ipsetName] r, ok := s.ipsets[ipsetName]
return r, ok return r, ok
} }
func (s *ipsetStore) addIpList(ipsetName string, list *ipList) { func (s *ipsetStore) addIpList(ipsetName string, list ipList) {
s.ipsets[ipsetName] = list s.ipsets[ipsetName] = list
} }
func (s *ipsetStore) deleteIpset(ipsetName string) { func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName) delete(s.ipsets, ipsetName)
} }
@@ -70,24 +48,3 @@ func (s *ipsetStore) ipsetNames() []string {
} }
return names return names
} }
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
return nil
}

View File

@@ -1,70 +0,0 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@@ -1,23 +1,15 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s"
NatFormat = "netbird-nat-%s-%t" InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
) )
// Rule abstraction should be implemented by each firewall manager // Rule abstraction should be implemented by each firewall manager
@@ -54,16 +46,14 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
AllowNetbird() error AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddFiltering(
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
@@ -74,116 +64,25 @@ type Manager interface {
comment string, comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
DeletePeerRule(rule Rule) error DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) // InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// DeleteRouteRule deletes a routing rule // RemoveRoutingRules removes a routing firewall rule
DeleteRouteRule(rule Rule) error RemoveRoutingRules(pair RouterPair) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Reset firewall to the default state
Reset(stateManager *statemanager.Manager) error Reset() error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, input string) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse) return fmt.Sprintf(format, input)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
} }

View File

@@ -1,192 +0,0 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -1,26 +1,18 @@
package manager package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct { type RouterPair struct {
ID route.ID ID string
Source netip.Prefix Source string
Destination netip.Prefix Destination string
Masquerade bool Masquerade bool
Inverse bool
} }
func GetInversePair(pair RouterPair) RouterPair { func GetInPair(pair RouterPair) RouterPair {
return RouterPair{ return RouterPair{
ID: pair.ID, ID: pair.ID,
// invert Source/Destination // invert Source/Destination
Source: pair.Destination, Source: pair.Destination,
Destination: pair.Source, Destination: pair.Source,
Masquerade: pair.Masquerade, Masquerade: pair.Masquerade,
Inverse: true,
} }
} }

View File

@@ -11,13 +11,12 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/iface"
) )
const ( const (
@@ -30,63 +29,72 @@ const (
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter" chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
const flushError = "flush: %w"
var ( var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
) )
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routeingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
} }
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations // and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000) // overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting()) sConn, err := nftables.New(nftables.AsLasting())
if err != nil { if err != nil {
return nil, fmt.Errorf("create nf conn: %w", err) return nil, err
} }
return &AclManager{ m := &AclManager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routingFwChainName: routingFwChainName, routeingFwChainName: routeingFwChainName,
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
}, nil }
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
} }
func (m *AclManager) init(workTable *nftables.Table) error { // AddFiltering rule to the firewall
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -112,11 +120,20 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules = append(newRules, ioRule) newRules = append(newRules, ioRule)
if !shouldAddToPrerouting(proto, dPort, direction) {
return newRules, nil
}
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
if err != nil {
return newRules, err
}
newRules = append(newRules, preroutingRule)
return newRules, nil return newRules, nil
} }
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { func (m *AclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
@@ -182,7 +199,8 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
// createDefaultAllowRules creates default allow rules for the input and output chains // createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
// input and output chains
func (m *AclManager) createDefaultAllowRules() error { func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{ expIn := []expr.Any{
&expr.Payload{ &expr.Payload{
@@ -196,13 +214,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0, 0, 0, 0}, Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: []byte{0, 0, 0, 0}, Xor: zeroXor,
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0x00, 0x00, 0x00, 0x00},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -228,13 +246,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0, 0, 0, 0}, Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: []byte{0, 0, 0, 0}, Xor: zeroXor,
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0x00, 0x00, 0x00, 0x00},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -248,8 +266,10 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut, Exprs: expOut,
}) })
if err := m.rConn.Flush(); err != nil { err := m.rConn.Flush()
return fmt.Errorf(flushError, err) if err != nil {
log.Debugf("failed to create default allow rules: %s", err)
return err
} }
return nil return nil
} }
@@ -270,11 +290,15 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
}
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, r.nftRule,
@@ -284,7 +308,18 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil }, nil
} }
var expressions []expr.Any ifaceKey := expr.MetaKeyIIFNAME
if direction == firewall.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
@@ -294,15 +329,21 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1), Len: uint32(1),
}) })
protoData, err := protoToInt(proto) var protoData []byte
if err != nil { switch proto {
return nil, fmt.Errorf("convert protocol to number: %v", err) case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
} }
expressions = append(expressions, &expr.Cmp{ expressions = append(expressions, &expr.Cmp{
Register: 1, Register: 1,
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Data: []byte{protoData}, Data: protoData,
}) })
} }
@@ -391,9 +432,10 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else { } else {
chain = m.chainOutputRules chain = m.chainOutputRules
} }
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
@@ -411,13 +453,139 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil return rule, nil
} }
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
var protoData []byte
switch proto {
case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
}, nil
}
var ipExpression expr.Any
// add individual IP for match if no ipset defined
rawIP := ip.To4()
if ipset == nil {
ipExpression = &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
}
} else {
ipExpression = &expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
}
}
expressions := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
ipExpression,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
},
}
if port != nil {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*port),
},
)
}
expressions = append(expressions,
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Position: 0,
Exprs: expressions,
UserData: []byte(ruleId),
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err) log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err) return err
} }
m.chainInputRules = chain m.chainInputRules = chain
@@ -433,6 +601,9 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME) m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
@@ -444,6 +615,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-output-filter // netbird-acl-output-filter
// type filter hook output priority filter; policy accept; // type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME) m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
@@ -454,106 +626,29 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
// netbird-acl-forward-filter // netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd m.addJumpRulesToRtForward() // to
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) m.addMarkAccept()
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err) return err
} }
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { // netbird-acl-output-filter
log.Errorf("failed to allow redirected traffic: %s", err) // type filter hook output priority filter; policy accept;
m.chainPrerouting = m.createPreroutingMangle()
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
return err
} }
return nil return nil
} }
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) func (m *AclManager) addJumpRulesToRtForward() {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -563,15 +658,68 @@ func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: m.routingFwChainName, Chain: m.routeingFwChainName,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chainFwFilter, Chain: m.chainFwFilter,
Exprs: expressions, Exprs: expressions,
}) })
expressions = []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addMarkAccept() {
// oifname "wt0" meta mark 0x000007e4 accept
// iifname "wt0" meta mark 0x000007e4 accept
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
for _, iface := range ifaces {
expressions := []expr.Any{
&expr.Meta{Key: iface, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: postroutingMark,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
} }
func (m *AclManager) createChain(name string) *nftables.Chain { func (m *AclManager) createChain(name string) *nftables.Chain {
@@ -581,13 +729,10 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
} }
chain = m.rConn.AddChain(chain) chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain return chain
} }
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: name, Name: name,
@@ -601,6 +746,74 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.Ch
return m.rConn.AddChain(chain) return m.rConn.AddChain(chain)
} }
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: "netbird-acl-prerouting-filter",
Table: m.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return chain
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
@@ -619,9 +832,101 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil return nil
} }
func (m *AclManager) addJumpRuleToInputChain() {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if netIfName == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
} else {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
}
expressions := []expr.Any{
&expr.Meta{Key: netIfName, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq var srcOp, dstOp expr.CmpOp
if iifname == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
} else {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
}
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1}, &expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -629,6 +934,24 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{ &expr.Payload{
DestRegister: 2, DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -659,6 +982,7 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
} }
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -666,12 +990,47 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: to, Chain: to,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
@@ -773,7 +1132,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil return nil
} }
func generatePeerRuleId( func generateRuleId(
ip net.IP, ip net.IP,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -796,6 +1155,33 @@ func generatePeerRuleId(
} }
return "set:" + ipset.Name + rulesetID return "set:" + ipset.Name + rulesetID
} }
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
// case of icmp port is empty
var p string
if port != nil {
p = port.String()
}
if ipset != nil {
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
} else {
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil && proto != firewall.ProtocolICMP {
return false
}
return true
}
func encodePort(port firewall.Port) []byte { func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
@@ -805,19 +1191,6 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte { func ifname(n string) []byte {
b := make([]byte, 16) b := make([]byte, 16)
copy(b, n+"\x00") copy(b, []byte(n+"\x00"))
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@@ -5,34 +5,20 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client // tableName is the name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird" tableName = "netbird"
tableNameFilter = "filter"
chainNameInput = "INPUT"
) )
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@@ -44,73 +30,35 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
} }
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} workTable, err := m.createWorkTable()
var err error
m.router, err = newRouter(workTable, wgIface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) return nil, err
} }
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) m.router, err = newRouter(context, workTable)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, err
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err
} }
return m, nil return m, nil
} }
// Init nftables firewall manager // AddFiltering rule to the firewall
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil
}
// AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -128,52 +76,33 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { return m.aclManager.DeleteRule(rule)
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) return m.router.AddRoutingRules(pair)
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) return m.router.RemoveRoutingRules(pair)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -197,7 +126,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward { if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c chain = c
break break
} }
@@ -228,86 +157,47 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains() chains, err := m.rConn.ListChains()
if err != nil { if err != nil {
return fmt.Errorf("list chains: %w", err) return fmt.Errorf("list of chains: %w", err)
} }
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists
if c.Table.Name == "filter" && c.Name == "INPUT" { if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
continue continue
} }
for _, r := range rules {
m.deleteMatchingRules(rules) if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
} if err := m.rConn.DelRule(r); err != nil {
} log.Errorf("delete rule: %v", err)
} }
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
} }
} }
} }
}
func (m *Manager) cleanupNetbirdTables() error { m.router.ResetForwardRules()
tables, err := m.rConn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list tables: %w", err) return fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return nil
return m.rConn.Flush()
} }
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
@@ -328,12 +218,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }
@@ -349,7 +239,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Verdict{}, &expr.Verdict{
Kind: expr.VerdictAccept,
},
}, },
UserData: []byte(allowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }
@@ -359,7 +251,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name()) ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules { for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 { if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue continue
@@ -373,38 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
} }
return nil return nil
} }
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: getEstablishedExprs(1),
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -1,6 +1,7 @@
package nftables package nftables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -8,30 +9,14 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -55,15 +40,28 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false } func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -72,7 +70,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering( rule, err := manager.AddFiltering(
ip, ip,
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
@@ -90,35 +88,17 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules") require.Len(t, rules, 1, "expected 1 rules")
expectedExprs1 := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -154,10 +134,10 @@ func TestNftablesManager(t *testing.T) {
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule { for _, r := range rule {
err = manager.DeletePeerRule(r) err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -166,10 +146,9 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// established rule remains require.Len(t, rules, 0, "expected 0 rules after deletion")
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil) err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -192,13 +171,12 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -209,9 +187,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -0,0 +1,431 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -1,885 +0,0 @@
package nftables
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
)
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
}
return r, nil
}
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
return nil
}
// Reset cleans existing nftables default forward rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.removeAcceptForwardRules()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
})
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
}
exprs = append(exprs, &expr.Verdict{Kind: verdict})
rule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
}
rule = r.conn.AddRule(rule)
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
return nil
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
var elements []nftables.SetElement
for _, prefix := range sources {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
}
}
return ""
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
}
delete(r.rules, ruleKey)
log.Debugf("removed route rule %s", ruleKey)
return nil
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME
notDir := expr.MetaKeyOIFNAME
if pair.Inverse {
dir = expr.MetaKeyOIFNAME
notDir = expr.MetaKeyIIFNAME
}
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
&expr.Meta{
Key: dir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: lo,
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
exprs = append(exprs,
&expr.Counter{}, &expr.Masq{},
)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
exprs := []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *router) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
}
r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes a nftables rule pair from nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: nat rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, 32)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: mask,
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
},
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
offset := uint32(2) // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
},
)
} else {
// Handle single port or multiple ports
for i, p := range port.Values {
if i > 0 {
// Add a bitwise OR operation between port checks
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
})
}
}
return exprs
}

View File

@@ -3,15 +3,12 @@
package nftables package nftables
import ( import (
"encoding/binary" "context"
"net/netip"
"os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -27,57 +24,56 @@ const (
NFTABLES NFTABLES
) )
func TestNftablesManager_AddNatRule(t *testing.T) { func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable() table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table") if err != nil {
t.Fatal(err)
}
defer deleteWorkTable() defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock) manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router") require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { defer manager.ResetForwardRules()
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair) err = manager.AddRoutingRules(testCase.InputPair)
require.NoError(t, err, "pair should be inserted") defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) { sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
}(manager, testCase.InputPair) testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -92,26 +88,27 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
if testCase.InputPair.Masquerade { sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) testingExpression = append(sourceExp, destExp...) //nolint:gocritic
testingExpression := append(sourceExp, destExp...) //nolint:gocritic inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -125,38 +122,45 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
}) })
} }
} }
func TestNftablesManager_RemoveNatRule(t *testing.T) { func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable() table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table") if err != nil {
t.Fatal(err)
}
defer deleteWorkTable() defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(table, ifaceMock) manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router") require.NoError(t, err, "failed to create router")
require.NoError(t, manager.init(table))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { defer manager.ResetForwardRules()
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable, Table: manager.workTable,
@@ -165,11 +169,20 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
UserData: []byte(natRuleKey), UserData: []byte(natRuleKey),
}) })
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable, Table: manager.workTable,
@@ -181,10 +194,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
err = nftablesTestingClient.Flush() err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.Reset() manager.ResetForwardRules()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains { for _, chain := range manager.chains {
@@ -192,7 +204,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
} }
} }
@@ -201,470 +215,6 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
} }
} }
func TestRouter_AddRouteFiltering(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify internal rule content
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Check if the rule exists in nftables and verify its content
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err, "Failed to get rules from nftables")
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule
break
}
}
require.NotNil(t, nftRule, "Rule not found in nftables")
t.Log("Actual nftables rule expressions:")
for i, expr := range nftRule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
})
}
}
func TestNftablesCreateIpSet(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IP",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
},
{
name: "Multiple IPs",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("172.16.0.1/32"),
},
},
{
name: "Single Subnet",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
{
name: "Multiple Subnets with Various Prefix Lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("203.0.113.0/26"),
},
},
{
name: "Mix of Single IPs and Subnets in Different Positions",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
netip.MustParsePrefix("203.0.113.0/24"),
},
},
{
name: "Overlapping IPs/Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}
// Add this helper function inside TestNftablesCreateIpSet
printNftSets := func() {
cmd := exec.Command("nft", "list", "sets")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to run 'nft list sets': %v", err)
} else {
t.Logf("Current nft sets:\n%s", output)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()
require.NoError(t, err, "Failed to create IP set")
}
require.NotNil(t, set, "Created set is nil")
// Verify set properties
assert.Equal(t, setName, set.Name, "Set name mismatch")
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
assert.True(t, set.Interval, "Set interval property should be true")
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
// Fetch the created set from nftables
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
require.NotNil(t, fetchedSet, "Fetched set is nil")
// Verify set elements
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
// Count the number of unique prefixes (excluding interval end markers)
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
// Check against expected merged prefixes
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
// Verify each expected prefix is in the set
for _, expected := range tt.expected {
found := false
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
if expected.Contains(ip) {
found = true
break
}
}
}
assert.True(t, found, "Expected prefix %s not found in set", expected)
}
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
t.Logf("Failed to delete set: %v", err)
printNftSets()
}
require.NoError(t, err, "Failed to delete set")
})
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
// Verify sources and destination
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
// Verify ports
if sPort != nil {
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
}
if dPort != nil {
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
}
// Verify action
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
}
func containsSetLookup(exprs []expr.Any) bool {
for _, e := range exprs {
if _, ok := e.(*expr.Lookup); ok {
return true
}
}
return false
}
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
var offset uint32
if isSource {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
var payloadFound, bitwiseFound, cmpFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
payloadFound = true
}
case *expr.Bitwise:
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
bitwiseFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
cmpFound = true
}
}
}
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
var payloadFound, portMatchFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
portMatchFound = true
}
} else {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
portMatchFound = true
break
}
}
}
}
}
if payloadFound && portMatchFound {
return true
}
}
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
if ex.Key == expr.MetaKeyL4PROTO {
metaFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
cmpFound = true
}
}
}
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
}
return false
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int { func check() int {
nf := nftables.Conn{} nf := nftables.Conn{}
@@ -700,12 +250,12 @@ func createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush() err = sConn.Flush()
return table, err return table, err
@@ -723,7 +273,7 @@ func deleteWorkTable() {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }

View File

@@ -1 +0,0 @@
package nftables

View File

@@ -1,47 +0,0 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@@ -1,10 +1,8 @@
//go:build !android
package test package test
import ( import firewall "github.com/netbirdio/netbird/client/firewall/manager"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ( var (
InsertRuleTestCases = []struct { InsertRuleTestCases = []struct {
@@ -15,8 +13,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -24,8 +22,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -40,8 +38,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -2,10 +2,8 @@
package uspfilter package uspfilter
import "github.com/netbirdio/netbird/client/internal/statemanager"
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -13,7 +11,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Reset()
} }
return nil return nil
} }

View File

@@ -6,8 +6,6 @@ import (
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type action string type action string
@@ -19,7 +17,7 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(*statemanager.Manager) error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -3,7 +3,6 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -12,9 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
@@ -25,7 +22,7 @@ var (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(iface.PacketFilter) error
Address() iface.WGAddress Address() iface.WGAddress
} }
@@ -98,10 +95,6 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return false return false
@@ -110,26 +103,26 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.InsertRoutingRules(pair)
} }
// RemoveNatRule removes a routing firewall rule // RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveRoutingRules(pair)
} }
// AddPeerFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -195,22 +188,8 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { // DeleteRule from the firewall by rule definition
if m.nativeFirewall == nil { func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.DeleteRouteRule(rule)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -236,14 +215,6 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -424,7 +395,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeletePeerRule(&rule) return m.DeleteRule(&rule)
} }
} }
} }
@@ -432,7 +403,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeletePeerRule(&rule) return m.DeleteRule(&rule)
} }
} }
} }

View File

@@ -11,16 +11,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
if i.SetFilterFunc == nil { if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented") return fmt.Errorf("not implemented")
} }
@@ -36,7 +35,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -50,10 +49,10 @@ func TestManagerCreate(t *testing.T) {
} }
} }
func TestManagerAddPeerFiltering(t *testing.T) { func TestManagerAddFiltering(t *testing.T) {
isSetFilterCalled := false isSetFilterCalled := false
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { SetFilterFunc: func(iface.PacketFilter) error {
isSetFilterCalled = true isSetFilterCalled = true
return nil return nil
}, },
@@ -72,7 +71,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -91,7 +90,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) { func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -107,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -120,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule { for _, r := range rule {
err = m.DeletePeerRule(r) err = m.DeleteRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -141,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
err = m.DeletePeerRule(r) err = m.DeleteRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -237,7 +236,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) { func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -253,13 +252,13 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
err = m.Reset(nil) err = m.Reset()
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -272,7 +271,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -291,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -330,7 +329,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if err = m.Reset(nil); err != nil { if err = m.Reset(); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -340,7 +339,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) { func TestRemovePacketHook(t *testing.T) {
// creating mock iface // creating mock iface
iface := &IFaceMock{ iface := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
// creating manager instance // creating manager instance
@@ -389,14 +388,14 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Reset(); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -407,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -1,5 +0,0 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint

View File

@@ -1,275 +0,0 @@
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
// ICEBind is a bind implementation with two main features:
// 1. filter out STUN messages and handle them
// 2. forward the received packets to the WireGuard interface from the relayed connection
//
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}
func (s *ICEBind) Close() error {
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
return s.StdNetBind.Close()
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}
for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}

View File

@@ -1,5 +0,0 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@@ -1,9 +0,0 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}

View File

@@ -1,18 +0,0 @@
//go:build !android
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,20 +0,0 @@
package device
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -1,4 +0,0 @@
package device
// CustomWindowsGUIDString is a custom GUID string for the interface
var CustomWindowsGUIDString string

View File

@@ -1,16 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,24 +0,0 @@
package iface
import (
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -1,10 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@@ -1,24 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,34 +0,0 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,26 +0,0 @@
//go:build ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,45 +0,0 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -1,32 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -1,137 +0,0 @@
package bind
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
)
type ProxyBind struct {
Bind *bind.ICEBind
wgAddr *net.UDPAddr
wgEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
if err != nil {
return err
}
p.wgAddr = addr
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return err
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyBind) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *ProxyBind) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
if p.closed {
return nil
}
p.closed = true
p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr)
return p.remoteConn.Close()
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
}()
buf := make([]byte, 1500)
for {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
msg := bind.RecvMessage{
Endpoint: p.wgEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
}
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}

View File

@@ -1,126 +0,0 @@
//go:build linux && !android
package ebpf
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
log "github.com/sirupsen/logrus"
)
// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
type ProxyWrapper struct {
WgeBPFProxy *WGEBPFProxy
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
wgEndpointAddr *net.UDPAddr
pausedMu sync.Mutex
paused bool
isStarted bool
}
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil {
return fmt.Errorf("add turn conn: %w", err)
}
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
p.wgEndpointAddr = addr
return err
}
func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr {
return p.wgEndpointAddr
}
func (p *ProxyWrapper) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyWrapper) Pause() {
if p.remoteConn == nil {
return
}
log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr())
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (e *ProxyWrapper) CloseConn() error {
if e.cancel == nil {
return fmt.Errorf("proxy not started")
}
e.cancel()
if err := e.remoteConn.Close(); err != nil {
return fmt.Errorf("failed to close remote conn: %w", err)
}
return nil
}
func (p *ProxyWrapper) proxyToLocal(ctx context.Context) {
defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port))
buf := make([]byte, 1500)
for {
n, err := p.readFromRemote(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port)
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, error) {
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return 0, ctx.Err()
}
if !errors.Is(err, io.EOF) {
log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err)
}
return 0, err
}
return n, nil
}

View File

@@ -1,49 +0,0 @@
//go:build linux && !android
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
type KernelFactory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewKernelFactory(wgPort int) *KernelFactory {
f := &KernelFactory{
wgPort: wgPort,
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
f.ebpfProxy = ebpfProxy
return f
}
func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -1,29 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -1,30 +0,0 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind *bind.ICEBind
}
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: iceBind,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{
Bind: w.bind,
}
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -1,15 +0,0 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
}

View File

@@ -1,56 +0,0 @@
//go:build linux && !android
package wgproxy
import (
"context"
"os"
"testing"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,128 +0,0 @@
//go:build linux
package wgproxy
import (
"context"
"io"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util"
)
func TestMain(m *testing.M) {
_ = util.InitLog("trace", "console")
code := m.Run()
os.Exit(code)
}
type mocConn struct {
closeChan chan struct{}
closed bool
}
func newMockConn() *mocConn {
return &mocConn{
closeChan: make(chan struct{}),
}
}
func (m *mocConn) Read(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Write(b []byte) (n int, err error) {
<-m.closeChan
return 0, io.EOF
}
func (m *mocConn) Close() error {
if m.closed == true {
return nil
}
m.closed = true
close(m.closeChan)
return nil
}
func (m *mocConn) LocalAddr() net.Addr {
panic("implement me")
}
func (m *mocConn) RemoteAddr() net.Addr {
return &net.UDPAddr{
IP: net.ParseIP("172.16.254.1"),
}
}
func (m *mocConn) SetDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetReadDeadline(t time.Time) error {
panic("implement me")
}
func (m *mocConn) SetWriteDeadline(t time.Time) error {
panic("implement me")
}
func TestProxyCloseByRemoteConn(t *testing.T) {
ctx := context.Background()
tests := []struct {
name string
proxy Proxy
}{
{
name: "userspace proxy",
proxy: udpProxy.NewWGUDPProxy(51830),
},
}
if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
proxyWrapper := &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
}
tests = append(tests, struct {
name string
proxy Proxy
}{
name: "ebpf proxy",
proxy: proxyWrapper,
})
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -1,207 +0,0 @@
package udp
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
cerrors "github.com/netbirdio/netbird/client/errors"
)
// WGUDPProxy proxies
type WGUDPProxy struct {
localWGListenPort int
remoteConn net.Conn
localConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUDPProxy{
localWGListenPort: wgPort,
}
return p
}
// AddTurnConn
// The provided Context must be non-nil. If the context expires before
// the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the
// connection.
func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil {
log.Errorf("failed dialing to local Wireguard port %s", err)
return err
}
p.ctx, p.cancel = context.WithCancel(ctx)
p.localConn = localConn
p.remoteConn = remoteConn
return err
}
func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil {
return nil
}
endpointUdpAddr, _ := net.ResolveUDPAddr(p.localConn.LocalAddr().Network(), p.localConn.LocalAddr().String())
return endpointUdpAddr
}
// Work starts the proxy or resumes it if it was paused
func (p *WGUDPProxy) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
if !p.isStarted {
p.isStarted = true
go p.proxyToRemote(p.ctx)
go p.proxyToLocal(p.ctx)
}
}
// Pause pauses the proxy from receiving data from the remote peer
func (p *WGUDPProxy) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
// CloseConn close the localConn
func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *WGUDPProxy) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
// prevent double close
if p.closed {
return nil
}
p.closed = true
p.cancel()
var result *multierror.Error
if err := p.remoteConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
}
if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
}
return cerrors.FormatErrorOrNil(result)
}
// proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err)
}
}()
buf := make([]byte, 1500)
for ctx.Err() == nil {
n, err := p.localConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to read from wg interface conn: %s", err)
return
}
_, err = p.remoteConn.Write(buf[:n])
if err != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to write to remote conn: %s", err)
return
}
}
}
// proxyToLocal proxies from the Remote peer to local WireGuard
// if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
if !errors.Is(err, io.EOF) {
log.Warnf("error in proxy to local loop: %s", err)
}
}
}()
buf := make([]byte, 1500)
for {
n, err := p.remoteConnRead(ctx, buf)
if err != nil {
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
_, err = p.localConn.Write(buf[:n])
p.pausedMu.Unlock()
if err != nil {
if ctx.Err() != nil {
return
}
log.Debugf("failed to write to wg interface conn: %s", err)
continue
}
}
}
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
n, err = p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
return
}
return
}

View File

@@ -1,64 +0,0 @@
package id
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
type RuleID string
func (r RuleID) GetRuleID() string {
return string(r)
}
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination netip.Prefix,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,
action manager.Action,
) RuleID {
manager.SortPrefixes(sources)
h := sha256.New()
// Write all fields to the hasher, with delimiters
h.Write([]byte("sources:"))
for _, src := range sources {
h.Write([]byte(src.String()))
h.Write([]byte(","))
}
h.Write([]byte("destination:"))
h.Write([]byte(destination.String()))
h.Write([]byte("proto:"))
h.Write([]byte(proto))
h.Write([]byte("sPort:"))
if sPort != nil {
h.Write([]byte(sPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("dPort:"))
if dPort != nil {
h.Write([]byte(dPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("action:"))
h.Write([]byte(strconv.Itoa(int(action))))
hash := hex.EncodeToString(h.Sum(nil))
// prepend destination prefix to be able to identify the rule
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
}

View File

@@ -3,26 +3,19 @@ package acl
import ( import (
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"strconv" "strconv"
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap)
@@ -30,18 +23,16 @@ type Manager interface {
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
type DefaultManager struct { type DefaultManager struct {
firewall firewall.Manager firewall firewall.Manager
ipsetCounter int ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule rulesPairs map[string][]firewall.Rule
routeRules map[id.RuleID]struct{} mutex sync.Mutex
mutex sync.Mutex
} }
func NewDefaultManager(fm firewall.Manager) *DefaultManager { func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{ return &DefaultManager{
firewall: fm, firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule), rulesPairs: make(map[string][]firewall.Rule),
routeRules: make(map[id.RuleID]struct{}),
} }
} }
@@ -55,7 +46,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
for _, pairs := range d.peerRulesPairs { for _, pairs := range d.rulesPairs {
total += len(pairs) total += len(pairs)
} }
log.Infof( log.Infof(
@@ -68,34 +59,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return return
} }
d.applyPeerACLs(networkMap) defer func() {
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}()
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
}
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := networkMap.PeerConfig != nil && enableSSH := (networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled networkMap.PeerConfig.SshConfig.SshEnabled)
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
enableSSH = enableSSH && !ok enableSSH = enableSSH && !ok
} }
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
enableSSH = enableSSH && !ok enableSSH = enableSSH && !ok
} }
@@ -105,9 +83,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
if enableSSH { if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{ rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort), Port: strconv.Itoa(ssh.DefaultSSHPort),
}) })
} }
@@ -119,20 +97,20 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules = append(rules, rules = append(rules,
&mgmProto.FirewallRule{ &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
&mgmProto.FirewallRule{ &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
) )
} }
newRulePairs := make(map[id.RuleID][]firewall.Rule) newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string) ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules { for _, r := range rules {
@@ -152,106 +130,29 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
break break
} }
if len(rules) > 0 { if len(rules) > 0 {
d.peerRulesPairs[pairID] = rulePair d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair newRulePairs[pairID] = rulePair
} }
} }
for pairID, rules := range d.peerRulesPairs { for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok { if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules { for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil { if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err) log.Errorf("failed to delete firewall rule: %v", err)
continue continue
} }
} }
delete(d.peerRulesPairs, pairID) delete(d.rulesPairs, pairID)
} }
} }
d.peerRulesPairs = newRulePairs d.rulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
continue
}
newRouteRules[id] = struct{}{}
}
// Clean up old firewall rules
for id := range d.routeRules {
if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
}
// implicitly deleted from the map
}
}
d.routeRules = newRouteRules
return nberrors.FormatErrorOrNil(merr)
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty
}
var sources []netip.Prefix
for _, sourceRange := range rule.SourceRanges {
source, err := netip.ParsePrefix(sourceRange)
if err != nil {
return "", fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, source)
}
var destination netip.Prefix
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
if err != nil {
return "", fmt.Errorf("invalid protocol: %w", err)
}
action, err := convertFirewallAction(rule.Action)
if err != nil {
return "", fmt.Errorf("invalid action: %w", err)
}
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
if err != nil {
return "", fmt.Errorf("add route rule: %w", err)
}
return id.RuleID(addedRule.GetRuleID()), nil
} }
func (d *DefaultManager) protoRuleToFirewallRule( func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule, r *mgmProto.FirewallRule,
ipsetName string, ipsetName string,
) (id.RuleID, []firewall.Rule, error) { ) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP) ip := net.ParseIP(r.PeerIP)
if ip == nil { if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -278,16 +179,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
} }
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { if rulesPair, ok := d.rulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil return ruleID, rulesPair, nil
} }
var rules []firewall.Rule var rules []firewall.Rule
switch r.Direction { switch r.Direction {
case mgmProto.RuleDirection_IN: case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT: case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -309,7 +210,7 @@ func (d *DefaultManager) addInRules(
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering( rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -320,7 +221,7 @@ func (d *DefaultManager) addInRules(
return rules, nil return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -338,7 +239,7 @@ func (d *DefaultManager) addOutRules(
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering( rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -349,7 +250,7 @@ func (d *DefaultManager) addOutRules(
return rules, nil return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -358,21 +259,21 @@ func (d *DefaultManager) addOutRules(
return append(rules, rule...), nil return append(rules, rule...), nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID( func (d *DefaultManager) getRuleID(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
direction int, direction int,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
comment string, comment string,
) id.RuleID { ) string {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil { if port != nil {
idStr += port.String() idStr += port.String()
} }
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
} }
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
@@ -382,7 +283,7 @@ func (d *DefaultManager) getPeerRuleID(
// but other has port definitions or has drop policy. // but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules( func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap, networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { ) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
totalIPs := 0 totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps { for range p.AllowedIps {
@@ -390,14 +291,14 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
type protoMatch map[mgmProto.RuleProtocol]map[string]int type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
in := protoMatch{} in := protoMatch{}
out := protoMatch{} out := protoMatch{}
// trace which type of protocols was squashed // trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{} squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not. // this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list. // We summ amount of Peers IP for given protocol we found in original rules list.
@@ -407,7 +308,7 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
if drop { if drop {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = map[string]int{}
return return
@@ -435,7 +336,7 @@ func (d *DefaultManager) squashAcceptRules(
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
// calculate squash for different directions // calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN { if r.Direction == mgmProto.FirewallRule_IN {
addRuleToCalculationMap(i, r, in) addRuleToCalculationMap(i, r, in)
} else { } else {
addRuleToCalculationMap(i, r, out) addRuleToCalculationMap(i, r, out)
@@ -444,14 +345,14 @@ func (d *DefaultManager) squashAcceptRules(
// order of squashing by protocol is important // order of squashing by protocol is important
// only for their first element ALL, it must be done first // only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{ protocolOrders := []mgmProto.FirewallRuleProtocol{
mgmProto.RuleProtocol_ALL, mgmProto.FirewallRule_ALL,
mgmProto.RuleProtocol_ICMP, mgmProto.FirewallRule_ICMP,
mgmProto.RuleProtocol_TCP, mgmProto.FirewallRule_TCP,
mgmProto.RuleProtocol_UDP, mgmProto.FirewallRule_UDP,
} }
squash := func(matches protoMatch, direction mgmProto.RuleDirection) { squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if : // don't squash if :
@@ -464,12 +365,12 @@ func (d *DefaultManager) squashAcceptRules(
squashedRules = append(squashedRules, &mgmProto.FirewallRule{ squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: direction, Direction: direction,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: protocol, Protocol: protocol,
}) })
squashedProtocols[protocol] = struct{}{} squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL { if protocol == mgmProto.FirewallRule_ALL {
// if we have ALL traffic type squashed rule // if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing // it allows all other type of traffic, so we can stop processing
break break
@@ -477,11 +378,11 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
squash(in, mgmProto.RuleDirection_IN) squash(in, mgmProto.FirewallRule_IN)
squash(out, mgmProto.RuleDirection_OUT) squash(out, mgmProto.FirewallRule_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules // if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
return squashedRules, squashedProtocols return squashedRules, squashedProtocols
} }
@@ -511,26 +412,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
} }
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state") log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs { for _, rules := range newRulePairs {
for _, rule := range rules { for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil { if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
} }
} }
} }
} }
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
switch protocol { switch protocol {
case mgmProto.RuleProtocol_TCP: case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP, nil return firewall.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP: case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP, nil return firewall.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP: case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP, nil return firewall.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL: case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL, nil return firewall.ProtocolALL, nil
default: default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
@@ -541,41 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
} }
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
switch action { switch action {
case mgmProto.RuleAction_ACCEPT: case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept, nil return firewall.ActionAccept, nil
case mgmProto.RuleAction_DROP: case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop, nil return firewall.ActionDrop, nil
default: default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
} }
} }
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
if portInfo == nil {
return nil
}
if portInfo.GetPort() != 0 {
return &firewall.Port{
Values: []int{int(portInfo.GetPort())},
}
}
if portInfo.GetRange() != nil {
return &firewall.Port{
IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
}
}
return nil
}
func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -1,6 +1,7 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
@@ -8,8 +9,8 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -18,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
Port: "80", Port: "80",
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_DROP, Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.RuleProtocol_UDP, Protocol: mgmProto.FirewallRule_UDP,
Port: "53", Port: "53",
}, },
}, },
@@ -51,29 +52,29 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil) fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Reset()
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.rulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
return return
} }
}) })
t.Run("add extra rules", func(t *testing.T) { t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{} existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs { for id := range acl.rulesPairs {
existedPairs[id.GetRuleID()] = struct{}{} existedPairs[id] = struct{}{}
} }
// remove first rule // remove first rule
@@ -82,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules, networkMap.FirewallRules,
&mgmProto.FirewallRule{ &mgmProto.FirewallRule{
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_DROP, Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.RuleProtocol_ICMP, Protocol: mgmProto.FirewallRule_ICMP,
}, },
) )
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules // we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 { if len(acl.rulesPairs) != 2 {
t.Errorf("firewall rules not applied") t.Errorf("firewall rules not applied")
return return
} }
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.rulesPairs {
if _, ok := existedPairs[id.GetRuleID()]; ok { if _, ok := existedPairs[id]; ok {
previousCount++ previousCount++
} }
} }
@@ -112,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
return return
} }
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.rulesPairs) != 2 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
return return
} }
}) })
@@ -137,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
}, },
} }
@@ -198,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0": case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return return
case r.Direction != mgmProto.RuleDirection_IN: case r.Direction != mgmProto.FirewallRule_IN:
t.Errorf("direction should be IN, got: %v", r.Direction) t.Errorf("direction should be IN, got: %v", r.Direction)
return return
case r.Protocol != mgmProto.RuleProtocol_ALL: case r.Protocol != mgmProto.FirewallRule_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol) t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return return
case r.Action != mgmProto.RuleAction_ACCEPT: case r.Action != mgmProto.FirewallRule_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action) t.Errorf("action should be ACCEPT, got: %v", r.Action)
return return
} }
@@ -214,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0": case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return return
case r.Direction != mgmProto.RuleDirection_OUT: case r.Direction != mgmProto.FirewallRule_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction) t.Errorf("direction should be OUT, got: %v", r.Direction)
return return
case r.Protocol != mgmProto.RuleProtocol_ALL: case r.Protocol != mgmProto.FirewallRule_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol) t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return return
case r.Action != mgmProto.RuleAction_ACCEPT: case r.Action != mgmProto.FirewallRule_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action) t.Errorf("action should be ACCEPT, got: %v", r.Action)
return return
} }
@@ -237,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
}, },
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP, Protocol: mgmProto.FirewallRule_UDP,
}, },
}, },
} }
@@ -307,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP, Protocol: mgmProto.FirewallRule_UDP,
}, },
}, },
} }
@@ -344,20 +345,20 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil) fw, err := firewall.NewFirewall(context.Background(), ifaceMock)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Reset()
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 { if len(acl.rulesPairs) != 4 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
return return
} }
} }

View File

@@ -8,8 +8,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
iface "github.com/netbirdio/netbird/client/iface" iface "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
// MockIFaceMapper is a mock of IFaceMapper interface. // MockIFaceMapper is a mock of IFaceMapper interface.
@@ -78,7 +77,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
} }
// SetFilter mocks base method. // SetFilter mocks base method.
func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error { func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetFilter", arg0) ret := m.ctrl.Call(m, "SetFilter", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -16,9 +16,9 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -117,11 +117,6 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) { if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{} config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil { if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err return nil, err
@@ -164,17 +159,13 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = WriteOutConfig(input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input) return update(input)
} }

View File

@@ -17,14 +17,13 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac" "github.com/netbirdio/netbird/relay/auth/hmac"
@@ -62,13 +61,16 @@ func (c *ConnectClient) Run() error {
} }
// RunWithProbes runs the client's main logic with probes attached // RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error { func (c *ConnectClient) RunWithProbes(
probes *ProbeHolder,
runningChan chan error,
) error {
return c.run(MobileDependency{}, probes, runningChan) return c.run(MobileDependency{}, probes, runningChan)
} }
// RunOnAndroid with main logic on mobile system // RunOnAndroid with main logic on mobile system
func (c *ConnectClient) RunOnAndroid( func (c *ConnectClient) RunOnAndroid(
tunAdapter device.TunAdapter, tunAdapter iface.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string, dnsAddresses []string,
@@ -101,7 +103,11 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error { func (c *ConnectClient) run(
mobileDependency MobileDependency,
probes *ProbeHolder,
runningChan chan error,
) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
@@ -110,6 +116,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@@ -146,7 +158,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
} }
defer c.statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { if c.isContextCancelled() {
@@ -193,7 +204,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(), PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
c.statusRecorder.UpdateLocalPeerState(localPeerState) c.statusRecorder.UpdateLocalPeerState(localPeerState)
@@ -268,22 +279,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil && runningChanOpen { if runningChan != nil {
runningChan <- nil runningChan <- nil
close(runningChan) close(runningChan)
runningChanOpen = false
} }
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()
@@ -345,11 +346,7 @@ func (c *ConnectClient) Stop() error {
if c.engine == nil { if c.engine == nil {
return nil return nil
} }
if err := c.engine.Stop(); err != nil { return c.engine.Stop()
return fmt.Errorf("stop engine: %w", err)
}
return nil
} }
func (c *ConnectClient) isContextCancelled() bool { func (c *ConnectClient) isContextCancelled() bool {

View File

@@ -1,5 +1,6 @@
package dns package dns
const ( const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
) )

View File

@@ -3,5 +3,6 @@
package dns package dns
const ( const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
) )

View File

@@ -9,8 +9,6 @@ import (
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
var ( var (
@@ -22,7 +20,7 @@ var (
} }
) )
type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error type repairConfFn func([]string, string, *resolvConf) error
type repair struct { type repair struct {
operationFile string operationFile string
@@ -42,7 +40,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
} }
} }
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) { func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) {
if f.inotify != nil { if f.inotify != nil {
return return
} }
@@ -83,7 +81,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
} }
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager) err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf)
if err != nil { if err != nil {
log.Errorf("failed to repair resolv.conf: %v", err) log.Errorf("failed to repair resolv.conf: %v", err)
} }

View File

@@ -9,7 +9,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -105,14 +104,14 @@ nameserver 8.8.8.8`,
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { updateFn := func([]string, string, *resolvConf) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(operationFile, updateFn) r := newRepair(operationFile, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
if err != nil { if err != nil {
@@ -152,14 +151,14 @@ searchdomain netbird.cloud something`
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error { updateFn := func([]string, string, *resolvConf) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(tmpLink, updateFn) r := newRepair(tmpLink, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil) r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1")
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
if err != nil { if err != nil {

View File

@@ -11,8 +11,6 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -38,7 +36,7 @@ type fileConfigurator struct {
nbNameserverIP string nbNameserverIP string
} }
func newFileConfigurator() (*fileConfigurator, error) { func newFileConfigurator() (hostManager, error) {
fc := &fileConfigurator{} fc := &fileConfigurator{}
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
return fc, nil return fc, nil
@@ -48,7 +46,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false return false
} }
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
backupFileExist := f.isBackupFileExist() backupFileExist := f.isBackupFileExist()
if !config.RouteAll { if !config.RouteAll {
if backupFileExist { if backupFileExist {
@@ -78,15 +76,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *st
f.repair.stopWatchFileChanges() f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager) err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf)
if err != nil { if err != nil {
return err return err
} }
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager) f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP)
return nil return nil
} }
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error { func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg) nameServers := generateNsList(nbNameserverIP, cfg)
@@ -109,7 +107,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf // create another backup for unclean shutdown detection right after overwriting the original resolv.conf
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil { if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
} }
@@ -147,6 +145,10 @@ func (f *fileConfigurator) restore() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
@@ -174,7 +176,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress) log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring")
return nil return nil
} }
@@ -190,6 +192,10 @@ func restoreResolvConfFile() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
}
return nil return nil
} }

View File

@@ -5,14 +5,14 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
type hostManager interface { type hostManager interface {
applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error applyDNSConfig(config HostDNSConfig) error
restoreHostDNS() error restoreHostDNS() error
supportCustomPort() bool supportCustomPort() bool
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type SystemDNSSettings struct { type SystemDNSSettings struct {
@@ -35,15 +35,15 @@ type DomainConfig struct {
} }
type mockHostConfigurator struct { type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error applyDNSConfigFunc func(config HostDNSConfig) error
restoreHostDNSFunc func() error restoreHostDNSFunc func() error
supportCustomPortFunc func() bool supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error restoreUncleanShutdownDNSFunc func(*netip.Addr) error
} }
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error {
if m.applyDNSConfigFunc != nil { if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config, stateManager) return m.applyDNSConfigFunc(config)
} }
return fmt.Errorf("method applyDNSSettings is not implemented") return fmt.Errorf("method applyDNSSettings is not implemented")
} }
@@ -62,9 +62,16 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false return false
} }
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
if m.restoreUncleanShutdownDNSFunc != nil {
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
}
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
}
func newNoopHostMocker() hostManager { func newNoopHostMocker() hostManager {
return &mockHostConfigurator{ return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil }, applyDNSConfigFunc: func(config HostDNSConfig) error { return nil },
restoreHostDNSFunc: func() error { return nil }, restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true }, supportCustomPortFunc: func() bool { return true },
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },

View File

@@ -1,17 +1,15 @@
package dns package dns
import ( import "net/netip"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager() (*androidHostManager, error) { func newHostManager() (hostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }
func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error {
return nil return nil
} }
@@ -22,3 +20,7 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) supportCustomPort() bool {
return false return false
} }
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@@ -8,13 +8,12 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -38,7 +37,7 @@ type systemConfigurator struct {
systemDNSSettings SystemDNSSettings systemDNSSettings SystemDNSSettings
} }
func newHostManager() (*systemConfigurator, error) { func newHostManager() (hostManager, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil
@@ -48,11 +47,12 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if err := stateManager.UpdateState(&ShutdownState{}); err != nil { // create a file for unclean shutdown detection
log.Errorf("failed to update shutdown state: %s", err) if err := createUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
} }
var ( var (
@@ -123,6 +123,10 @@ func (s *systemConfigurator) restoreHostDNS() error {
} }
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil return nil
} }
@@ -316,7 +320,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil return primaryService, router, nil
} }
func (s *systemConfigurator) restoreUncleanShutdownDNS() error { func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err) return fmt.Errorf("restoring dns via scutil: %w", err)
} }

View File

@@ -3,10 +3,9 @@ package dns
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type iosHostManager struct { type iosHostManager struct {
@@ -14,13 +13,13 @@ type iosHostManager struct {
config HostDNSConfig config HostDNSConfig
} }
func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) { func newHostManager(dnsManager IosDnsManager) (hostManager, error) {
return &iosHostManager{ return &iosHostManager{
dnsManager: dnsManager, dnsManager: dnsManager,
}, nil }, nil
} }
func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error { func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error {
jsonData, err := json.Marshal(config) jsonData, err := json.Marshal(config)
if err != nil { if err != nil {
return fmt.Errorf("marshal: %w", err) return fmt.Errorf("marshal: %w", err)
@@ -38,3 +37,7 @@ func (a iosHostManager) restoreHostDNS() error {
func (a iosHostManager) supportCustomPort() bool { func (a iosHostManager) supportCustomPort() bool {
return false return false
} }
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@@ -4,9 +4,9 @@ package dns
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"net/netip"
"os" "os"
"strings" "strings"
@@ -21,8 +21,27 @@ const (
resolvConfManager resolvConfManager
) )
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
type osManagerType int type osManagerType int
func newOsManagerType(osManager string) (osManagerType, error) {
switch osManager {
case "netbird":
return fileManager, nil
case "file":
return netbirdManager, nil
case "networkManager":
return networkManager, nil
case "systemd":
return systemdManager, nil
case "resolvconf":
return resolvConfManager, nil
default:
return 0, ErrUnknownOsManagerType
}
}
func (t osManagerType) String() string { func (t osManagerType) String() string {
switch t { switch t {
case netbirdManager: case netbirdManager:
@@ -40,11 +59,6 @@ func (t osManagerType) String() string {
} }
} }
type restoreHostManager interface {
hostManager
restoreUncleanShutdownDNS(*netip.Addr) error
}
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
@@ -55,7 +69,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return newHostManagerFromType(wgInterface, osManager) return newHostManagerFromType(wgInterface, osManager)
} }
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) { func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) {
switch osManager { switch osManager {
case networkManager: case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface) return newNetworkManagerDbusConfigurator(wgInterface)

View File

@@ -3,12 +3,11 @@ package dns
import ( import (
"fmt" "fmt"
"io" "io"
"net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -32,7 +31,7 @@ type registryConfigurator struct {
routingAll bool routingAll bool
} }
func newHostManager(wgInterface WGIface) (*registryConfigurator, error) { func newHostManager(wgInterface WGIface) (hostManager, error) {
guid, err := wgInterface.GetInterfaceGUIDString() guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -40,7 +39,7 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
return newHostManagerWithGuid(guid) return newHostManagerWithGuid(guid)
} }
func newHostManagerWithGuid(guid string) (*registryConfigurator, error) { func newHostManagerWithGuid(guid string) (hostManager, error) {
return &registryConfigurator{ return &registryConfigurator{
guid: guid, guid: guid,
}, nil }, nil
@@ -50,7 +49,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if config.RouteAll { if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP) err = r.addDNSSetupForAll(config.ServerIP)
@@ -66,8 +65,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil { // create a file for unclean shutdown detection
log.Errorf("failed to update shutdown state: %s", err) if err := createUncleanShutdownIndicator(r.guid); err != nil {
log.Errorf("failed to create unclean shutdown file: %s", err)
} }
var ( var (
@@ -160,6 +160,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil return nil
} }
@@ -217,7 +221,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
return regKey, nil return regKey, nil
} }
func (r *registryConfigurator) restoreUncleanShutdownDNS() error { func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err) return fmt.Errorf("restoring dns via registry: %w", err)
} }

View File

@@ -16,7 +16,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbversion "github.com/netbirdio/netbird/version" nbversion "github.com/netbirdio/netbird/version"
) )
@@ -54,7 +53,6 @@ var supportedNetworkManagerVersionConstraints = []string{
type networkManagerDbusConfigurator struct { type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
ifaceName string
} }
// the types below are based on dbus specification, each field is mapped to a dbus type // the types below are based on dbus specification, each field is mapped to a dbus type
@@ -79,7 +77,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) { func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, fmt.Errorf("get nm dbus: %w", err) return nil, fmt.Errorf("get nm dbus: %w", err)
@@ -95,7 +93,6 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusC
return &networkManagerDbusConfigurator{ return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil }, nil
} }
@@ -103,7 +100,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
return false return false
} }
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings() connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil { if err != nil {
return fmt.Errorf("retrieving the applied connection settings, error: %w", err) return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
@@ -154,12 +151,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
state := &ShutdownState{ // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
ManagerType: networkManager, // The file content itself is not important for network-manager restoration
WgIface: n.ifaceName, if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil {
} log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@@ -176,6 +171,10 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("delete connection settings: %w", err) return fmt.Errorf("delete connection settings: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return nil return nil
} }

View File

@@ -9,8 +9,6 @@ import (
"os/exec" "os/exec"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@@ -24,7 +22,7 @@ type resolvconf struct {
} }
// supported "openresolv" only // supported "openresolv" only
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) { func newResolvConfConfigurator(wgInterface string) (hostManager, error) {
resolvConfEntries, err := parseDefaultResolvConf() resolvConfEntries, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
@@ -42,7 +40,7 @@ func (r *resolvconf) supportCustomPort() bool {
return false return false
} }
func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
var err error var err error
if !config.RouteAll { if !config.RouteAll {
err = r.restoreHostDNS() err = r.restoreHostDNS()
@@ -62,12 +60,9 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
append([]string{config.ServerIP}, r.originalNameServers...), append([]string{config.ServerIP}, r.originalNameServers...),
options) options)
state := &ShutdownState{ // create a backup for unclean shutdown detection before the resolv.conf is changed
ManagerType: resolvConfManager, if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil {
WgIface: r.ifaceName, log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
}
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
err = r.applyConfig(buf) err = r.applyConfig(buf)
@@ -84,7 +79,11 @@ func (r *resolvconf) restoreHostDNS() error {
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err) return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
} }
return nil return nil
@@ -96,7 +95,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
cmd.Stdin = &content cmd.Stdin = &content
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err) return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err)
} }
return nil return nil
} }

View File

@@ -9,7 +9,7 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/iface/mocks"
) )
func TestResponseWriterLocalAddr(t *testing.T) { func TestResponseWriterLocalAddr(t *testing.T) {

View File

@@ -7,7 +7,6 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2" "github.com/mitchellh/hashstructure/v2"
@@ -15,7 +14,6 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@@ -65,7 +63,6 @@ type DefaultServer struct {
iosDnsManager IosDnsManager iosDnsManager IosDnsManager
statusRecorder *peer.Status statusRecorder *peer.Status
stateManager *statemanager.Manager
} }
type handlerWithStop interface { type handlerWithStop interface {
@@ -80,7 +77,12 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -97,7 +99,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -110,7 +112,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@@ -128,12 +130,12 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
return ds return ds
} }
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
@@ -145,7 +147,6 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
} }
@@ -168,7 +169,6 @@ func (s *DefaultServer) Initialize() (err error) {
} }
} }
s.stateManager.RegisterState(&ShutdownState{})
s.hostManager, err = s.initialize() s.hostManager, err = s.initialize()
if err != nil { if err != nil {
return fmt.Errorf("initialize: %w", err) return fmt.Errorf("initialize: %w", err)
@@ -191,10 +191,9 @@ func (s *DefaultServer) Stop() {
s.ctxCancel() s.ctxCancel()
if s.hostManager != nil { if s.hostManager != nil {
if err := s.hostManager.restoreHostDNS(); err != nil { err := s.hostManager.restoreHostDNS()
log.Error("failed to restore host DNS settings: ", err) if err != nil {
} else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil { log.Error(err)
log.Errorf("failed to delete shutdown dns state: %v", err)
} }
} }
@@ -319,17 +318,10 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate.RouteAll = false hostUpdate.RouteAll = false
} }
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil { if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil {
log.Error(err) log.Error(err)
} }
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
} }
@@ -529,17 +521,10 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
if err := s.stateManager.PersistState(ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()
} }
@@ -566,7 +551,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler) s.service.RegisterMux(nbdns.RootZone, handler)
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
} }

View File

@@ -15,19 +15,16 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
pfmock "github.com/netbirdio/netbird/iface/mocks"
) )
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter iface.PacketFilter
} }
func (w *mocWGIface) Name() string { func (w *mocWGIface) Name() string {
@@ -46,11 +43,11 @@ func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me") panic("implement me")
} }
func (w *mocWGIface) GetFilter() device.PacketFilter { func (w *mocWGIface) GetFilter() iface.PacketFilter {
return w.filter return w.filter
} }
func (w *mocWGIface) GetDevice() *device.FilteredDevice { func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
panic("implement me") panic("implement me")
} }
@@ -62,13 +59,13 @@ func (w *mocWGIface) IsUserspaceBind() bool {
return false return false
} }
func (w *mocWGIface) SetFilter(filter device.PacketFilter) error { func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
w.filter = filter w.filter = filter
return nil return nil
} }
func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) { func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
return configurer.WGStats{}, nil return iface.WGStats{}, nil
} }
var zoneRecords = []nbdns.SimpleRecord{ var zoneRecords = []nbdns.SimpleRecord{
@@ -268,17 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -292,7 +279,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -356,15 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{ wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
IFaceName: "utun2301",
Address: "100.66.100.1/32",
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil { if err != nil {
t.Errorf("build interface wireguard: %v", err) t.Errorf("build interface wireguard: %v", err)
return return
@@ -401,7 +380,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{})
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@@ -496,7 +475,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{})
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }
@@ -555,7 +534,6 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}), service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
@@ -572,7 +550,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
var domainsUpdate string var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error { hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error {
domains := []string{} domains := []string{}
for _, item := range config.Domains { for _, item := range config.Domains {
if item.Disabled { if item.Disabled {
@@ -823,17 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: "100.66.100.2/24",
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatalf("build interface wireguard: %v", err) t.Fatalf("build interface wireguard: %v", err)
return nil, err return nil, err

View File

@@ -1,5 +1,5 @@
package dns package dns
func (s *DefaultServer) initialize() (hostManager, error) { func (s *DefaultServer) initialize() (manager hostManager, err error) {
return newHostManager(s.wgInterface) return newHostManager(s.wgInterface)
} }

View File

@@ -7,7 +7,7 @@ import (
var errNotImplemented = errors.New("not implemented") var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(string) (restoreHostManager, error) { func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
} }

View File

@@ -15,7 +15,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@@ -39,7 +38,6 @@ const (
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
ifaceName string
} }
// the types below are based on dbus specification, each field is mapped to a dbus type // the types below are based on dbus specification, each field is mapped to a dbus type
@@ -57,7 +55,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) { func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
iface, err := net.InterfaceByName(wgInterface) iface, err := net.InterfaceByName(wgInterface)
if err != nil { if err != nil {
return nil, fmt.Errorf("get interface: %w", err) return nil, fmt.Errorf("get interface: %w", err)
@@ -79,7 +77,6 @@ func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, e
return &systemdDbusConfigurator{ return &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil }, nil
} }
@@ -87,7 +84,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
parsedIP, err := netip.ParseAddr(config.ServerIP) parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %w", err) return fmt.Errorf("unable to parse ip address, error: %w", err)
@@ -138,12 +135,10 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} }
state := &ShutdownState{ // create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file.
ManagerType: systemdManager, // The file content itself is not important for systemd restoration
WgIface: s.ifaceName, if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil {
} log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@@ -179,6 +174,10 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("unable to revert link configuration, got error: %w", err) return fmt.Errorf("unable to revert link configuration, got error: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return s.flushCaches() return s.flushCaches()
} }

View File

@@ -0,0 +1,5 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@@ -3,25 +3,57 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
) )
type ShutdownState struct { const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns"
}
func (s *ShutdownState) Name() string { func CheckUncleanShutdown(string) error {
return "dns_state" if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil {
} if errors.Is(err, fs.ErrNotExist) {
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
func (s *ShutdownState) Cleanup() error {
manager, err := newHostManager() manager, err := newHostManager()
if err != nil { if err != nil {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(); err != nil { if err := manager.restoreUncleanShutdownDNS(nil); err != nil {
return fmt.Errorf("restore unclean shutdown dns: %w", err) return fmt.Errorf("restore unclean shutdown backup: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator() error {
dir := filepath.Dir(fileUncleanShutdownFileLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}

View File

@@ -0,0 +1,5 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@@ -1,14 +0,0 @@
//go:build ios || android
package dns
type ShutdownState struct {
}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}

Some files were not shown because too many files have changed in this diff Show More