Compare commits

..

22 Commits

Author SHA1 Message Date
Misha Bragin
515ce9e3af Update management/server/sqlite_store.go 2024-04-17 20:55:32 +02:00
Misha Bragin
89383b7f01 Update management/server/sqlite_store.go 2024-04-17 20:55:01 +02:00
Misha Bragin
db34162733 Update management/server/sqlite_store.go 2024-04-17 20:54:14 +02:00
Misha Bragin
bd761e2177 Update management/server/sqlite_store.go 2024-04-17 20:53:32 +02:00
Misha Bragin
4e1b95a4c6 Update management/server/sqlite_store.go 2024-04-17 20:53:24 +02:00
Misha Bragin
05993af7bf Update management/server/sqlite_store.go 2024-04-17 20:53:11 +02:00
braginini
9d1cb00570 Fix setup keys test 2024-04-17 20:27:55 +02:00
braginini
543731df45 Fix setup keys test 2024-04-17 19:58:24 +02:00
braginini
e6628ec231 Fix setup keys 2024-04-17 19:48:09 +02:00
braginini
41d4dd2aff reduce log level of scheduler to trace 2024-04-17 19:34:59 +02:00
braginini
30bed57711 Fix account deletion 2024-04-17 19:12:53 +02:00
braginini
6960b68322 Add pats to test save account 2024-04-17 19:07:17 +02:00
braginini
3b3aa18148 Store setup keys and ns groups in a batch 2024-04-17 18:32:13 +02:00
braginini
93045f3e3a Fix rand lint issue 2024-04-17 18:07:02 +02:00
braginini
fd3c1dea8e Add save large account test 2024-04-17 18:02:10 +02:00
braginini
48aff7a26e Fix test compilation errors 2024-04-17 17:39:28 +02:00
braginini
83dfe8e3a3 Fix test compilation errors 2024-04-17 17:27:23 +02:00
braginini
38e10af2d9 Add accountID reference 2024-04-17 17:16:56 +02:00
braginini
99854a126a Add comments 2024-04-17 17:08:01 +02:00
braginini
a75f982fcd Copy account when storing to avoid reference issues 2024-04-17 17:03:21 +02:00
braginini
e7a6483912 Optimize all other objects storing in SQLite 2024-04-17 12:35:41 +02:00
braginini
30ede299b8 Optimize peer storing in SQLite 2024-04-17 11:50:33 +02:00
593 changed files with 17110 additions and 46165 deletions

View File

@@ -1,8 +0,0 @@
root = true
[*]
end_of_line = lf
insert_final_newline = true
[*.go]
indent_style = tab

View File

@@ -31,14 +31,9 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -dA output:** **NetBird status -d output:**
If applicable, add the `netbird status -dA' command output. If applicable, add the `netbird status -d' command output.
**Do you face any (non-mobile) client issues?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
**Screenshots** **Screenshots**

View File

@@ -14,18 +14,18 @@ jobs:
test: test:
strategy: strategy:
matrix: matrix:
store: ['sqlite'] store: ['jsonfile', 'sqlite']
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }} key: macos-go-${{ hashFiles('**/go.sum') }}
@@ -42,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -1,46 +0,0 @@
name: Test Code FreeBSD
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.1"
prepare: |
pkg install -y go
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -15,17 +15,17 @@ jobs:
strategy: strategy:
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'jsonfile', 'sqlite' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -80,16 +80,13 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin - name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin - name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin - name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/... run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
@@ -111,9 +108,6 @@ jobs:
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker - name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest runs-on: windows-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
id: go id: go
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Download wintun - name: Download wintun
uses: carlosperate/download-file-action@v2 uses: carlosperate/download-file-action@v2

View File

@@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif ignore_words_list: erro,clienta
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15 timeout-minutes: 15
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Check for duplicate constants - name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'
run: | run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
cache: false cache: false
- name: Install dependencies - name: Install dependencies
if: matrix.os == 'ubuntu-latest' if: matrix.os == 'ubuntu-latest'

View File

@@ -13,7 +13,6 @@ concurrency:
jobs: jobs:
test-install-script: test-install-script:
strategy: strategy:
fail-fast: false
max-parallel: 2 max-parallel: 2
matrix: matrix:
os: [ubuntu-latest, macos-latest] os: [ubuntu-latest, macos-latest]
@@ -22,7 +21,7 @@ jobs:
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run install script - name: run install script
env: env:

View File

@@ -15,30 +15,30 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Setup Android SDK - name: Setup Android SDK
uses: android-actions/setup-android@v3 uses: android-actions/setup-android@v3
with: with:
cmdline-tools-version: 8512546 cmdline-tools-version: 8512546
- name: Setup Java - name: Setup Java
uses: actions/setup-java@v4 uses: actions/setup-java@v3
with: with:
java-version: "11" java-version: "11"
distribution: "adopt" distribution: "adopt"
- name: NDK Cache - name: NDK Cache
id: ndk-cache id: ndk-cache
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: /usr/local/lib/android/sdk/ndk path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620 key: ndk-cache-23.1.7779620
- name: Setup NDK - name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620" run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build android netbird lib - name: build android netbird lib
@@ -50,16 +50,16 @@ jobs:
runs-on: macos-latest runs-on: macos-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: install gomobile - name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20230531173138-3c911d8e3eda
- name: gomobile init - name: gomobile init
run: gomobile init run: gomobile init
- name: build iOS netbird lib - name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o $GITHUB_WORKSPACE/NetBirdSDK.xcframework $GITHUB_WORKSPACE/client/ios/NetBirdSDK
env: env:
CGO_ENABLED: 0 CGO_ENABLED: 0

View File

@@ -3,16 +3,25 @@ name: Release
on: on:
push: push:
tags: tags:
- "v*" - 'v*'
branches: branches:
- main - main
pull_request: pull_request:
paths:
- 'go.mod'
- 'go.sum'
- '.goreleaser.yml'
- '.goreleaser_ui.yaml'
- '.goreleaser_ui_darwin.yaml'
- '.github/workflows/release.yml'
- 'release_files/**'
- '**/Dockerfile'
- '**/Dockerfile.*'
- 'client/ui/**'
env: env:
SIGN_PIPE_VER: "v0.0.14" SIGN_PIPE_VER: "v0.0.11"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v1.14.1"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -20,30 +29,26 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-22.04 runs-on: ubuntu-latest
env: env:
flags: "" flags: ""
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout -
uses: actions/checkout@v4 name: Checkout
uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go -
uses: actions/setup-go@v5 name: Set up Go
uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules -
uses: actions/cache@v4 name: Cache Go modules
uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -51,86 +56,74 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go-releaser- ${{ runner.os }}-go-releaser-
- name: Install modules -
name: Install modules
run: go mod tidy run: go mod tidy
- name: check git status -
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU -
name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx -
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v2
- name: Login to Docker hub -
name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@v1 uses: docker/login-action@v1
with: with:
username: ${{ secrets.DOCKER_USER }} username: netbirdio
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
- name: Install goversioninfo - name: Install rsrc
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows syso amd64 - name: Generate windows rsrc amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Generate windows rsrc arm64
run: rsrc -arch arm64 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm64.syso
- name: Generate windows rsrc arm
run: rsrc -arch arm -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_arm.syso
- name: Generate windows rsrc 386
run: rsrc -arch 386 -ico client/ui/netbird.ico -manifest client/manifest.xml -o client/resources_windows_386.syso
-
name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }} args: release --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes -
uses: actions/upload-artifact@v4 name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
- name: upload linux packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
- name: upload windows packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
- name: upload macos packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
retention-days: 3
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go - name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -147,44 +140,46 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64 run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
- name: Install goversioninfo - name: Install rsrc
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/akavel/rsrc@v0.10.2
- name: Generate windows syso amd64 - name: Generate windows rsrc
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: rsrc -arch amd64 -ico client/ui/netbird.ico -manifest client/ui/manifest.xml -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v3
with: with:
name: release-ui name: release-ui
path: dist/ path: dist/
retention-days: 3 retention-days: 3
release_ui_darwin: release_ui_darwin:
runs-on: macos-latest runs-on: macos-11
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout -
uses: actions/checkout@v4 name: Checkout
uses: actions/checkout@v3
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go -
uses: actions/setup-go@v5 name: Set up Go
uses: actions/setup-go@v4
with: with:
go-version: "1.23" go-version: "1.21"
cache: false cache: false
- name: Cache Go modules -
uses: actions/cache@v4 name: Cache Go modules
uses: actions/cache@v3
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
@@ -192,34 +187,52 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin- ${{ runner.os }}-ui-go-releaser-darwin-
- name: Install modules -
name: Install modules
run: go mod tidy run: go mod tidy
- name: check git status -
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Run GoReleaser -
name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes -
uses: actions/upload-artifact@v4 name: upload non tags for debug purposes
uses: actions/upload-artifact@v3
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
retention-days: 3 retention-days: 3
trigger_signer: trigger_windows_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin] needs: [release,release_ui]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger Windows binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1 uses: benc-uk/workflow-dispatch@v1
with: with:
workflow: Sign bin and installer workflow: Sign windows bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'
trigger_darwin_signer:
runs-on: ubuntu-latest
needs: [release,release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger Darwin App binaries sign pipeline
uses: benc-uk/workflow-dispatch@v1
with:
workflow: Sign darwin ui app with dispatch
repo: netbirdio/sign-pipelines repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}

View File

@@ -18,31 +18,7 @@ concurrency:
jobs: jobs:
test-docker-compose: test-docker-compose:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
steps: steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
- name: Install jq - name: Install jq
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
@@ -50,12 +26,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v4
with: with:
go-version: "1.23.x" go-version: "1.21.x"
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v3
with: with:
path: ~/go/pkg/mod path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -63,7 +39,7 @@ jobs:
${{ runner.os }}-go- ${{ runner.os }}-go-
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: cp setup.env - name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/ run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -82,8 +58,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values - name: check values
@@ -110,8 +85,7 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345 CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -149,14 +123,6 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -182,35 +148,26 @@ jobs:
run: | run: |
docker build -t netbirdio/signal:latest . docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
docker compose up -d docker-compose up -d
sleep 5 sleep 5
docker compose ps docker-compose ps
docker compose logs --tail=20 docker-compose logs --tail=20
- name: test running containers - name: test running containers
run: | run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running) count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs test $count -eq 4
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
- name: test geolocation databases - name: test geolocation databases
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
run: | run: |
sleep 30 sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
test-getting-started-script: test-getting-started-script:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -219,69 +176,36 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v3
- name: run script with Zitadel PostgreSQL - name: run script
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres - name: test Caddy file gen
run: test -f Caddyfile run: test -f Caddyfile
- name: test docker-compose file gen
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml run: test -f docker-compose.yml
- name: test management.json file gen
- name: test management.json file gen postgres
run: test -f management.json run: test -f management.json
- name: test turnserver.conf file gen
- name: test turnserver.conf file gen postgres
run: | run: |
set -x set -x
test -f turnserver.conf test -f turnserver.conf
grep external-ip turnserver.conf grep external-ip turnserver.conf
- name: test zitadel.env file gen
- name: test zitadel.env file gen postgres
run: test -f zitadel.env run: test -f zitadel.env
- name: test dashboard.env file gen
- name: test dashboard.env file gen postgres
run: test -f dashboard.env run: test -f dashboard.env
test-download-geolite2-script:
- name: test relay.env file gen postgres runs-on: ubuntu-latest
run: test -f relay.env steps:
- name: Install jq
- name: test zdb.env file gen postgres run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
run: test -f zdb.env - name: Checkout code
uses: actions/checkout@v3
- name: Postgres run cleanup - name: test script
run: | run: bash -x infrastructure_files/download-geolite2.sh
docker compose down --volumes --rmi all - name: test mmdb file exists
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env run: test -f GeoLite2-City.mmdb
- name: test geonames file exists
- name: run script with Zitadel CockroachDB run: test -f geonames.db
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- name: test relay.env file gen CockroachDB
run: test -f relay.env

1
.gitignore vendored
View File

@@ -29,3 +29,4 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env infrastructure_files/setup-*.env
.vscode .vscode
.DS_Store .DS_Store
GeoLite2-City*

View File

@@ -130,10 +130,3 @@ issues:
- path: mock\.go - path: mock\.go
linters: linters:
- nilnil - nilnil
# Exclude specific deprecation warnings for grpc methods
- linters:
- staticcheck
text: "grpc.DialContext is deprecated"
- linters:
- staticcheck
text: "grpc.WithBlock is deprecated"

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird - id: netbird
@@ -24,7 +22,7 @@ builds:
goarch: 386 goarch: 386
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -44,7 +42,7 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -66,7 +64,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-signal - id: netbird-signal
dir: signal dir: signal
@@ -80,21 +78,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-relay
dir: relay
env: [CGO_ENABLED=0]
binary: netbird-relay
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
archives: archives:
- builds: - builds:
@@ -102,6 +86,7 @@ archives:
- netbird-static - netbird-static
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client. description: Netbird client.
homepage: https://netbird.io/ homepage: https://netbird.io/
@@ -176,52 +161,6 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates: - image_templates:
- netbirdio/signal:{{ .Version }}-amd64 - netbirdio/signal:{{ .Version }}-amd64
ids: ids:
@@ -374,18 +313,6 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64 - netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }} - name_template: netbirdio/signal:{{ .Version }}
image_templates: image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8 - netbirdio/signal:{{ .Version }}-arm64v8
@@ -417,9 +344,10 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
brews: brews:
- ids: -
ids:
- default - default
repository: tap:
owner: netbirdio owner: netbirdio
name: homebrew-tap name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"

View File

@@ -1,5 +1,3 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui - id: netbird-ui
@@ -13,7 +11,9 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" tags:
- legacy_appindicator
mod_timestamp: '{{ .CommitTimestamp }}'
- id: netbird-ui-windows - id: netbird-ui-windows
dir: client/ui dir: client/ui
@@ -28,7 +28,7 @@ builds:
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui - -H windowsgui
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
archives: archives:
- id: linux-arch - id: linux-arch
@@ -41,6 +41,7 @@ archives:
- netbird-ui-windows - netbird-ui-windows
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
homepage: https://netbird.io/ homepage: https://netbird.io/

View File

@@ -1,14 +1,10 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
dir: client/ui dir: client/ui
binary: netbird-ui binary: netbird-ui
env: env: [CGO_ENABLED=1]
- CGO_ENABLED=1
- MACOSX_DEPLOYMENT_TARGET=11.0
- MACOS_DEPLOYMENT_TARGET=11.0
goos: goos:
- darwin - darwin
goarch: goarch:
@@ -19,7 +15,7 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: '{{ .CommitTimestamp }}'
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -30,4 +26,4 @@ archives:
checksum: checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt" name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog: changelog:
disable: true skip: true

View File

@@ -5,7 +5,7 @@
We as members, contributors, and leaders pledge to make participation in our We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socioeconomic status, identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation. identity and orientation.

View File

@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser** **Goreleaser**
```shell ```shell
goreleaser build --snapshot --clean goreleaser --snapshot --rm-dist
``` ```
**golangci-lint** **golangci-lint**
```shell ```shell

View File

@@ -10,14 +10,12 @@
<img width="234" src="docs/media/logo-full.png"/> <img width="234" src="docs/media/logo-full.png"/>
</p> </p>
<p> <p>
<a href="https://img.shields.io/badge/license-BSD--3-blue)">
<img src="https://sonarcloud.io/api/project_badges/measure?project=netbirdio_netbird&metric=alert_status" />
</a>
<a href="https://github.com/netbirdio/netbird/blob/main/LICENSE"> <a href="https://github.com/netbirdio/netbird/blob/main/LICENSE">
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<a href="https://www.codacy.com/gh/netbirdio/netbird/dashboard?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=netbirdio/netbird&amp;utm_campaign=Badge_Grade"><img src="https://app.codacy.com/project/badge/Grade/e3013d046aec44cdb7462c8673b00976"/></a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
</p> </p>
@@ -30,7 +28,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-vrahf41g-ik1v7fV8du6t0RwxSrJ96A">Slack channel</a>
<br/> <br/>
</strong> </strong>
@@ -46,11 +44,8 @@
### Open-Source Network Security in a Single Platform ### Open-Source Network Security in a Single Platform
![image](https://github.com/netbirdio/netbird/assets/700848/c0d7bae4-3301-499a-bb4e-5e4a225bf35f)
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features ### Key features
@@ -64,7 +59,6 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | | | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -1,5 +1,5 @@
FROM alpine:3.20 FROM alpine:3.18.5
RUN apk add --no-cache ca-certificates iptables ip6tables RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/go/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird COPY netbird /go/bin/netbird

View File

@@ -1,5 +1,3 @@
//go:build android
package android package android
import ( import (
@@ -8,7 +6,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -16,7 +13,7 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/iface"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -26,7 +23,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile // TunAdapter export internal TunAdapter for mobile
type TunAdapter interface { type TunAdapter interface {
device.TunAdapter iface.TunAdapter
} }
// IFaceDiscover export internal IFaceDiscover for mobile // IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,23 +48,20 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string cfgFile string
tunAdapter device.TunAdapter tunAdapter iface.TunAdapter
iFaceDiscover IFaceDiscover iFaceDiscover IFaceDiscover
recorder *peer.Status recorder *peer.Status
ctxCancel context.CancelFunc ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex ctxCancelLock *sync.Mutex
deviceName string deviceName string
uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
func NewClient(cfgFile, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client { func NewClient(cfgFile, deviceName string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
return &Client{ return &Client{
cfgFile: cfgFile, cfgFile: cfgFile,
deviceName: deviceName, deviceName: deviceName,
uiVersion: uiVersion,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
iFaceDiscover: iFaceDiscover, iFaceDiscover: iFaceDiscover,
recorder: peer.NewRecorder(""), recorder: peer.NewRecorder(""),
@@ -90,9 +84,6 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
var ctx context.Context var ctx context.Context
//nolint //nolint
ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName) ctxWithValues := context.WithValue(context.Background(), system.DeviceNameCtxKey, c.deviceName)
//nolint
ctxWithValues = context.WithValue(ctxWithValues, system.UiVersionCtxKey, c.uiVersion)
c.ctxCancelLock.Lock() c.ctxCancelLock.Lock()
ctx, c.ctxCancel = context.WithCancel(ctxWithValues) ctx, c.ctxCancel = context.WithCancel(ctxWithValues)
defer c.ctxCancel() defer c.ctxCancel()
@@ -106,8 +97,7 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -132,8 +122,7 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) return internal.RunClientMobile(ctx, cfg, c.recorder, c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources

View File

@@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) { func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) { err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil) _, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL) _, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err) s, ok := gstatus.FromError(err)

View File

@@ -1,227 +0,0 @@
package anonymize
import (
"crypto/rand"
"fmt"
"math/big"
"net"
"net/netip"
"net/url"
"regexp"
"slices"
"strings"
)
type Anonymizer struct {
ipAnonymizer map[netip.Addr]netip.Addr
domainAnonymizer map[string]string
currentAnonIPv4 netip.Addr
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
}
func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
return &Anonymizer{
ipAnonymizer: map[netip.Addr]netip.Addr{},
domainAnonymizer: map[string]string{},
currentAnonIPv4: startIPv4,
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
}
}
func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
if ip.IsLoopback() ||
ip.IsLinkLocalUnicast() ||
ip.IsLinkLocalMulticast() ||
ip.IsInterfaceLocalMulticast() ||
ip.IsPrivate() ||
ip.IsUnspecified() ||
ip.IsMulticast() ||
isWellKnown(ip) ||
a.isInAnonymizedRange(ip) {
return ip
}
if _, ok := a.ipAnonymizer[ip]; !ok {
if ip.Is4() {
a.ipAnonymizer[ip] = a.currentAnonIPv4
a.currentAnonIPv4 = a.currentAnonIPv4.Next()
} else {
a.ipAnonymizer[ip] = a.currentAnonIPv6
a.currentAnonIPv6 = a.currentAnonIPv6.Next()
}
}
return a.ipAnonymizer[ip]
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
return true
} else if !ip.Is4() && ip.Compare(a.startAnonIPv6) >= 0 && ip.Compare(a.currentAnonIPv6) <= 0 {
return true
}
return false
}
func (a *Anonymizer) AnonymizeIPString(ip string) string {
addr, err := netip.ParseAddr(ip)
if err != nil {
return ip
}
return a.AnonymizeIP(addr).String()
}
func (a *Anonymizer) AnonymizeDomain(domain string) string {
if strings.HasSuffix(domain, "netbird.io") ||
strings.HasSuffix(domain, "netbird.selfhosted") ||
strings.HasSuffix(domain, "netbird.cloud") ||
strings.HasSuffix(domain, "netbird.stage") ||
strings.HasSuffix(domain, ".domain") {
return domain
}
parts := strings.Split(domain, ".")
if len(parts) < 2 {
return domain
}
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
anonymized, ok := a.domainAnonymizer[baseDomain]
if !ok {
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
a.domainAnonymizer[baseDomain] = anonymizedBase
anonymized = anonymizedBase
}
return strings.Replace(domain, baseDomain, anonymized, 1)
}
func (a *Anonymizer) AnonymizeURI(uri string) string {
u, err := url.Parse(uri)
if err != nil {
return uri
}
var anonymizedHost string
if u.Opaque != "" {
host, port, err := net.SplitHostPort(u.Opaque)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Opaque)
}
u.Opaque = anonymizedHost
} else if u.Host != "" {
host, port, err := net.SplitHostPort(u.Host)
if err == nil {
anonymizedHost = fmt.Sprintf("%s:%s", a.AnonymizeDomain(host), port)
} else {
anonymizedHost = a.AnonymizeDomain(u.Host)
}
u.Host = anonymizedHost
}
return u.String()
}
func (a *Anonymizer) AnonymizeString(str string) string {
ipv4Regex := regexp.MustCompile(`\b(?:[0-9]{1,3}\.){3}[0-9]{1,3}\b`)
ipv6Regex := regexp.MustCompile(`\b([0-9a-fA-F:]+:+[0-9a-fA-F]{0,4})(?:%[0-9a-zA-Z]+)?(?:\/[0-9]{1,3})?(?::[0-9]{1,5})?\b`)
str = ipv4Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
str = ipv6Regex.ReplaceAllStringFunc(str, a.AnonymizeIPString)
for domain, anonDomain := range a.domainAnonymizer {
str = strings.ReplaceAll(str, domain, anonDomain)
}
str = a.AnonymizeSchemeURI(str)
str = a.AnonymizeDNSLogLine(str)
return str
}
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, ".domain") {
return match
}
randomDomain := generateRandomString(10) + ".domain"
return strings.Replace(match, domain, randomDomain, 1)
}
return match
})
}
// AnonymizeRoute anonymizes a route string by replacing IP addresses with anonymized versions and
// domain names with random strings.
func (a *Anonymizer) AnonymizeRoute(route string) string {
prefix, err := netip.ParsePrefix(route)
if err == nil {
ip := a.AnonymizeIPString(prefix.Addr().String())
return fmt.Sprintf("%s/%d", ip, prefix.Bits())
}
domains := strings.Split(route, ", ")
for i, domain := range domains {
domains[i] = a.AnonymizeDomain(domain)
}
return strings.Join(domains, ", ")
}
func isWellKnown(addr netip.Addr) bool {
wellKnown := []string{
"8.8.8.8", "8.8.4.4", // Google DNS IPv4
"2001:4860:4860::8888", "2001:4860:4860::8844", // Google DNS IPv6
"1.1.1.1", "1.0.0.1", // Cloudflare DNS IPv4
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
}
if slices.Contains(wellKnown, addr.String()) {
return true
}
cgnatRangeStart := netip.AddrFrom4([4]byte{100, 64, 0, 0})
cgnatRange := netip.PrefixFrom(cgnatRangeStart, 10)
return cgnatRange.Contains(addr)
}
func generateRandomString(length int) string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
num, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
continue
}
result[i] = letters[num.Int64()]
}
return string(result)
}

View File

@@ -1,223 +0,0 @@
package anonymize_test
import (
"net/netip"
"regexp"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
)
func TestAnonymizeIP(t *testing.T) {
startIPv4 := netip.MustParseAddr("198.51.100.0")
startIPv6 := netip.MustParseAddr("100::")
anonymizer := anonymize.NewAnonymizer(startIPv4, startIPv6)
tests := []struct {
name string
ip string
expect string
}{
{"Well known", "8.8.8.8", "8.8.8.8"},
{"First Public IPv4", "1.2.3.4", "198.51.100.0"},
{"Second Public IPv4", "4.3.2.1", "198.51.100.1"},
{"Repeated IPv4", "1.2.3.4", "198.51.100.0"},
{"Private IPv4", "192.168.1.1", "192.168.1.1"},
{"First Public IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Second Public IPv6", "a::b", "100::1"},
{"Repeated IPv6", "2607:f8b0:4005:805::200e", "100::"},
{"Private IPv6", "fe80::1", "fe80::1"},
{"In Range IPv4", "198.51.100.2", "198.51.100.2"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ip := netip.MustParseAddr(tc.ip)
anonymizedIP := anonymizer.AnonymizeIP(ip)
if anonymizedIP.String() != tc.expect {
t.Errorf("%s: expected %s, got %s", tc.name, tc.expect, anonymizedIP)
}
})
}
}
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
}
func TestAnonymizeDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
domain string
expectPattern string
shouldAnonymize bool
}{
{
"General Domain",
"example.com",
`^anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Subdomain",
"sub.example.com",
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
true,
},
{
"Protected Domain",
"netbird.io",
`^netbird\.io$`,
false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDomain(tc.domain)
if tc.shouldAnonymize {
assert.Regexp(t, tc.expectPattern, result, "The anonymized domain should match the expected pattern")
assert.NotContains(t, result, tc.domain, "The original domain should not be present in the result")
} else {
assert.Equal(t, tc.domain, result, "Protected domains should not be anonymized")
}
})
}
}
func TestAnonymizeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
uri string
regex string
}{
{
"HTTP URI with Port",
"http://example.com:80/path",
`^http://anon-[a-zA-Z0-9]+\.domain:80/path$`,
},
{
"HTTP URI without Port",
"http://example.com/path",
`^http://anon-[a-zA-Z0-9]+\.domain/path$`,
},
{
"Opaque URI with Port",
"stun:example.com:80?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain:80\?transport=udp$`,
},
{
"Opaque URI without Port",
"stun:example.com?transport=udp",
`^stun:anon-[a-zA-Z0-9]+\.domain\?transport=udp$`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeURI(tc.uri)
assert.Regexp(t, regexp.MustCompile(tc.regex), result, "URI should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizeSchemeURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
tests := []struct {
name string
input string
expect string
}{
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeSchemeURI(tc.input)
assert.Regexp(t, tc.expect, result, "The anonymized output should match expected pattern")
require.NotContains(t, result, "example.com", "Original domain should not be present")
})
}
}
func TestAnonymizString_MemorizedDomain(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "This is a test string including the domain example.com which should be anonymized."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, anonymizedDomain, "The domain should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, domain, "The original domain should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the string")
}
func TestAnonymizeString_DoubleURI(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
domain := "example.com"
anonymizedDomain := anonymizer.AnonymizeDomain(domain)
sampleString := "Check out our site at https://example.com for more info."
firstPassResult := anonymizer.AnonymizeString(sampleString)
secondPassResult := anonymizer.AnonymizeString(firstPassResult)
assert.Contains(t, firstPassResult, "https://"+anonymizedDomain, "The URI should be anonymized in the first pass")
assert.NotContains(t, firstPassResult, "https://example.com", "The original URI should not appear in the first pass output")
assert.Equal(t, firstPassResult, secondPassResult, "The second pass should not further anonymize the URI")
}
func TestAnonymizeString_IPAddresses(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
tests := []struct {
name string
input string
expect string
}{
{
name: "IPv4 Address",
input: "Error occurred at IP 122.138.1.1",
expect: "Error occurred at IP 198.51.100.0",
},
{
name: "IPv6 Address",
input: "Access attempted from 2001:db8::ff00:42",
expect: "Access attempted from 100::",
},
{
name: "IPv6 Address with Port",
input: "Access attempted from [2001:db8::ff00:42]:8080",
expect: "Access attempted from [100::]:8080",
},
{
name: "Both IPv4 and IPv6",
input: "IPv4: 142.108.0.1 and IPv6: 2001:db8::ff00:43",
expect: "IPv4: 198.51.100.1 and IPv6: 100::1",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeString(tc.input)
assert.Equal(t, tc.expect, result, "IP addresses should be anonymized correctly")
})
}
}

View File

@@ -1,273 +0,0 @@
package cmd
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server"
)
const errCloseConnection = "Failed to close connection: %v"
var debugCmd = &cobra.Command{
Use: "debug",
Short: "Debugging commands",
Long: "Provides commands for debugging and logging control within the Netbird daemon.",
}
var debugBundleCmd = &cobra.Command{
Use: "bundle",
Example: " netbird debug bundle",
Short: "Create a debug bundle",
Long: "Generates a compressed archive of the daemon's logs and status for debugging purposes.",
RunE: debugBundle,
}
var logCmd = &cobra.Command{
Use: "log",
Short: "Manage logging for the Netbird daemon",
Long: `Commands to manage logging settings for the Netbird daemon, including ICE, gRPC, and general log levels.`,
}
var logLevelCmd = &cobra.Command{
Use: "level <level>",
Short: "Set the logging level for this session",
Long: `Sets the logging level for the current session. This setting is temporary and will revert to the default on daemon restart.
Available log levels are:
panic: for panic level, highest level of severity
fatal: for fatal level errors that cause the program to exit
error: for error conditions
warn: for warning conditions
info: for informational messages
debug: for debug-level messages
trace: for trace-level messages, which include more fine-grained information than debug`,
Args: cobra.ExactArgs(1),
RunE: setLogLevel,
}
var forCmd = &cobra.Command{
Use: "for <time>",
Short: "Run debug logs for a specified duration and create a debug bundle",
Long: `Sets the logging level to trace, runs for the specified duration, and then generates a debug bundle.`,
Example: " netbird debug for 5m",
Args: cobra.ExactArgs(1),
RunE: runForDuration,
}
func debugBundle(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd),
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
cmd.Println(resp.GetPath())
return nil
}
func setLogLevel(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
level := server.ParseLogLevel(args[0])
if level == proto.LogLevel_UNKNOWN {
return fmt.Errorf("unknown log level: %s. Available levels are: panic, fatal, error, warn, info, debug, trace\n", args[0])
}
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: level,
})
if err != nil {
return fmt.Errorf("failed to set log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level set successfully to", args[0])
return nil
}
func runForDuration(cmd *cobra.Command, args []string) error {
duration, err := time.ParseDuration(args[0])
if err != nil {
return fmt.Errorf("invalid duration format: %v", err)
}
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
log.Errorf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
stat, err := client.Status(cmd.Context(), &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("failed to get status: %v", status.Convert(err).Message())
}
stateWasDown := stat.Status != string(internal.StatusConnected) && stat.Status != string(internal.StatusConnecting)
initialLogLevel, err := client.GetLogLevel(cmd.Context(), &proto.GetLogLevelRequest{})
if err != nil {
return fmt.Errorf("failed to get log level: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(time.Second * 10)
}
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
if !initialLevelTrace {
_, err = client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{
Level: proto.LogLevel_TRACE,
})
if err != nil {
return fmt.Errorf("failed to set log level to TRACE: %v", status.Convert(err).Message())
}
cmd.Println("Log level set to trace.")
}
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
time.Sleep(1 * time.Second)
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
}
cmd.Println("Netbird up")
time.Sleep(3 * time.Second)
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
cmd.Println("Creating debug bundle...")
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag,
Status: statusOutput,
SystemInfo: debugSystemInfoFlag,
})
if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
}
if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
}
cmd.Println("Netbird down")
}
if !initialLevelTrace {
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
}
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
}
cmd.Println(resp.GetPath())
return nil
}
func getStatusOutput(cmd *cobra.Command) string {
var statusOutputString string
statusResp, err := getStatus(cmd.Context())
if err != nil {
cmd.PrintErrf("Failed to get status: %v\n", err)
} else {
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
}
return statusOutputString
}
func waitForDurationOrCancel(ctx context.Context, duration time.Duration, cmd *cobra.Command) error {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
startTime := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
elapsed := time.Since(startTime)
if elapsed >= duration {
return
}
remaining := duration - elapsed
cmd.Printf("\rRemaining time: %s", formatDuration(remaining))
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
}
}
func formatDuration(d time.Duration) string {
d = d.Round(time.Second)
h := d / time.Hour
d %= time.Hour
m := d / time.Minute
d %= time.Minute
s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
}

View File

@@ -2,9 +2,8 @@ package cmd
import ( import (
"context" "context"
"time"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@@ -26,7 +25,7 @@ var downCmd = &cobra.Command{
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second*7) ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel() defer cancel()
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
@@ -42,8 +41,6 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err) log.Errorf("call service down method: %v", err)
return err return err
} }
cmd.Println("Disconnected")
return nil return nil
}, },
} }

View File

@@ -39,11 +39,6 @@ var loginCmd = &cobra.Command{
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName) ctx = context.WithValue(ctx, system.DeviceNameCtxKey, hostName)
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
// workaround to run without service // workaround to run without service
if logFile == "console" { if logFile == "console" {
err = handleRebrand(cmd) err = handleRebrand(cmd)
@@ -67,7 +62,7 @@ var loginCmd = &cobra.Command{
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -86,7 +81,7 @@ var loginCmd = &cobra.Command{
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: setupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName, Hostname: hostName,

View File

@@ -5,8 +5,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -32,12 +32,9 @@ const (
preSharedKeyFlag = "preshared-key" preSharedKeyFlag = "preshared-key"
interfaceNameFlag = "interface-name" interfaceNameFlag = "interface-name"
wireguardPortFlag = "wireguard-port" wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect" disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh" serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
) )
var ( var (
@@ -56,7 +53,6 @@ var (
managementURL string managementURL string
adminURL string adminURL string
setupKey string setupKey string
setupKeyPath string
hostName string hostName string
preSharedKey string preSharedKey string
natExternalIPs []string natExternalIPs []string
@@ -66,14 +62,9 @@ var (
serverSSHAllowed bool serverSSHAllowed bool
interfaceName string interfaceName string
wireguardPort uint16 wireguardPort uint16
networkMonitor bool
serviceName string serviceName string
autoConnectDisabled bool autoConnectDisabled bool
extraIFaceBlackList []string extraIFaceBlackList []string
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
Short: "", Short: "",
@@ -94,15 +85,12 @@ func init() {
oldDefaultConfigPathDir = "/etc/wiretrustee/" oldDefaultConfigPathDir = "/etc/wiretrustee/"
oldDefaultLogFileDir = "/var/log/wiretrustee/" oldDefaultLogFileDir = "/var/log/wiretrustee/"
switch runtime.GOOS { if runtime.GOOS == "windows" {
case "windows":
defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\" defaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Netbird\\"
oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultConfigPathDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\" oldDefaultLogFileDir = os.Getenv("PROGRAMDATA") + "\\Wiretrustee\\"
case "freebsd":
defaultConfigPathDir = "/var/db/netbird/"
} }
defaultConfigPath = defaultConfigPathDir + "config.json" defaultConfigPath = defaultConfigPathDir + "config.json"
@@ -127,14 +115,10 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location") rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", defaultConfigPath, "Netbird config file location")
rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level") rootCmd.PersistentFlags().StringVarP(&logLevel, "log-level", "l", "info", "sets Netbird log level")
rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout. If syslog is specified the log will be sent to syslog daemon.") rootCmd.PersistentFlags().StringVar(&logFile, "log-file", defaultLogFile, "sets Netbird log path. If console is specified the log will be output to stdout")
rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)") rootCmd.PersistentFlags().StringVarP(&setupKey, "setup-key", "k", "", "Setup key obtained from the Management Service Dashboard (used to register peer)")
rootCmd.PersistentFlags().StringVar(&setupKeyPath, "setup-key-file", "", "The path to a setup key obtained from the Management Service Dashboard (used to register peer) This is ignored if the setup-key flag is provided.")
rootCmd.MarkFlagsMutuallyExclusive("setup-key", "setup-key-file")
rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.") rootCmd.PersistentFlags().StringVar(&preSharedKey, preSharedKeyFlag, "", "Sets Wireguard PreSharedKey property. If set, then only peers that have the same key can communicate.")
rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device") rootCmd.PersistentFlags().StringVarP(&hostName, "hostname", "n", "", "Sets a custom hostname for the device")
rootCmd.PersistentFlags().BoolVarP(&anonymizeFlag, "anonymize", "A", false, "anonymize IP addresses and non-netbird.io domains in logs and status output")
rootCmd.AddCommand(serviceCmd) rootCmd.AddCommand(serviceCmd)
rootCmd.AddCommand(upCmd) rootCmd.AddCommand(upCmd)
rootCmd.AddCommand(downCmd) rootCmd.AddCommand(downCmd)
@@ -142,20 +126,8 @@ func init() {
rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd)
debugCmd.AddCommand(forCmd)
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`Sets external IPs maps between local addresses and interfaces.`+ `Sets external IPs maps between local addresses and interfaces.`+
`You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+ `You can specify a comma-separated list with a single IP and IP/IP or IP/Interface Name. `+
@@ -173,8 +145,6 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success
@@ -256,21 +226,6 @@ var CLIBackOffSettings = &backoff.ExponentialBackOff{
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
func getSetupKey() (string, error) {
if setupKeyPath != "" && setupKey == "" {
return getSetupKeyFromFile(setupKeyPath)
}
return setupKey, nil
}
func getSetupKeyFromFile(setupKeyPath string) (string, error) {
data, err := os.ReadFile(setupKeyPath)
if err != nil {
return "", fmt.Errorf("failed to read setup key file: %v", err)
}
return strings.TrimSpace(string(data)), nil
}
func handleRebrand(cmd *cobra.Command) error { func handleRebrand(cmd *cobra.Command) error {
var err error var err error
if logFile == defaultLogFile { if logFile == defaultLogFile {
@@ -380,17 +335,3 @@ func migrateToNetbird(oldPath, newPath string) bool {
return true return true
} }
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
SetFlagsFromEnvVars(rootCmd)
cmd.SetOut(cmd.OutOrStdout())
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
return conn, nil
}

View File

@@ -4,10 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"testing" "testing"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/iface"
) )
func TestInitCommands(t *testing.T) { func TestInitCommands(t *testing.T) {
@@ -38,44 +34,3 @@ func TestInitCommands(t *testing.T) {
}) })
} }
} }
func TestSetFlagsFromEnvVars(t *testing.T) {
var cmd = &cobra.Command{
Use: "netbird",
Long: "test",
SilenceUsage: true,
Run: func(cmd *cobra.Command, args []string) {
SetFlagsFromEnvVars(cmd)
},
}
cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
`comma separated list of external IPs to map to the Wireguard interface`)
cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
t.Setenv("NB_INTERFACE_NAME", "test-name")
t.Setenv("NB_ENABLE_ROSENPASS", "true")
t.Setenv("NB_WIREGUARD_PORT", "10000")
err := cmd.Execute()
if err != nil {
t.Fatalf("expected no error while running netbird command, got %v", err)
}
if len(natExternalIPs) != 2 {
t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
}
if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
}
if interfaceName != "test-name" {
t.Errorf("expected test-name, got %s", interfaceName)
}
if !rosenpassEnabled {
t.Errorf("expected rosenpassEnabled to be true, got false")
}
if wireguardPort != 10000 {
t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
}
}

View File

@@ -1,174 +0,0 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -2,21 +2,18 @@ package cmd
import ( import (
"context" "context"
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/server"
) )
type program struct { type program struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
serv *grpc.Server serv *grpc.Server
serverInstance *server.Server
} }
func newProgram(ctx context.Context, cancel context.CancelFunc) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,8 +61,6 @@ func (p *program) Start(svc service.Service) error {
} }
proto.RegisterDaemonServiceServer(p.serv, serverInstance) proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstance = serverInstance
log.Printf("started daemon server: %v", split[1]) log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil { if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err) log.Errorf("failed to serve daemon requests: %v", err)
@@ -72,14 +70,6 @@ func (p *program) Start(svc service.Service) error {
} }
func (p *program) Stop(srv service.Service) error { func (p *program) Stop(srv service.Service) error {
if p.serverInstance != nil {
in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in)
if err != nil {
log.Errorf("failed to stop daemon: %v", err)
}
}
p.cancel() p.cancel()
if p.serv != nil { if p.serv != nil {

View File

@@ -31,8 +31,6 @@ var installCmd = &cobra.Command{
configPath, configPath,
"--log-level", "--log-level",
logLevel, logLevel,
"--daemon-addr",
daemonAddr,
} }
if managementURL != "" { if managementURL != "" {

View File

@@ -24,7 +24,7 @@ var (
) )
var sshCmd = &cobra.Command{ var sshCmd = &cobra.Command{
Use: "ssh [user@]host", Use: "ssh",
Args: func(cmd *cobra.Command, args []string) error { Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 { if len(args) < 1 {
return errors.New("requires a host argument") return errors.New("requires a host argument")
@@ -94,7 +94,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command)
if err != nil { if err != nil {
cmd.Printf("Error: %v\n", err) cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" + "You can verify the connection by running:\n\n" +
" netbird status\n\n") " netbird status\n\n")
return err return err
} }

View File

@@ -6,8 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"os"
"runtime"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -16,7 +14,6 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
@@ -31,9 +28,9 @@ type peerStateDetailOutput struct {
Status string `json:"status" yaml:"status"` Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"` ConnType string `json:"connectionType" yaml:"connectionType"`
Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"` IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"` IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"` TransferSent int64 `json:"transferSent" yaml:"transferSent"`
@@ -147,9 +144,9 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed initializing log %v", err) return fmt.Errorf("failed initializing log %v", err)
} }
ctx := internal.CtxInitState(cmd.Context()) ctx := internal.CtxInitState(context.Background())
resp, err := getStatus(ctx) resp, err := getStatus(ctx, cmd)
if err != nil { if err != nil {
return err return err
} }
@@ -194,7 +191,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func getStatus(ctx context.Context) (*proto.StatusResponse, error) { func getStatus(ctx context.Context, cmd *cobra.Command) (*proto.StatusResponse, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr) conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+
@@ -203,7 +200,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
} }
defer conn.Close() defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true}) resp, err := proto.NewDaemonServiceClient(conn).Status(cmd.Context(), &proto.StatusRequest{GetFullPeerStatus: true})
if err != nil { if err != nil {
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
} }
@@ -286,11 +283,6 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
} }
if anonymizeFlag {
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
anonymizeOverview(anonymizer, &overview)
}
return overview return overview
} }
@@ -335,18 +327,16 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
func mapPeers(peers []*proto.PeerState) peersStateOutput { func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput var peersStateDetail []peerStateDetailOutput
peersConnected := 0
for _, pbPeerState := range peers {
localICE := "" localICE := ""
remoteICE := "" remoteICE := ""
localICEEndpoint := "" localICEEndpoint := ""
remoteICEEndpoint := "" remoteICEEndpoint := ""
relayServerAddress := ""
connType := "" connType := ""
peersConnected := 0
lastHandshake := time.Time{} lastHandshake := time.Time{}
transferReceived := int64(0) transferReceived := int64(0)
transferSent := int64(0) transferSent := int64(0)
for _, pbPeerState := range peers {
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) { if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue continue
@@ -362,7 +352,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
if pbPeerState.Relayed { if pbPeerState.Relayed {
connType = "Relayed" connType = "Relayed"
} }
relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx() transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx() transferSent = pbPeerState.GetBytesTx()
@@ -375,6 +364,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Status: pbPeerState.GetConnStatus(), Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal, LastStatusUpdate: timeLocal,
ConnType: connType, ConnType: connType,
Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: localICE, Local: localICE,
Remote: remoteICE, Remote: remoteICE,
@@ -383,7 +373,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Local: localICEEndpoint, Local: localICEEndpoint,
Remote: remoteICEEndpoint, Remote: remoteICEEndpoint,
}, },
RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(), FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake, LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived, TransferReceived: transferReceived,
@@ -536,15 +525,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total) peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
goos := runtime.GOOS
goarch := runtime.GOARCH
goarm := ""
if goarch == "arm" {
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
}
summary := fmt.Sprintf( summary := fmt.Sprintf(
"OS: %s\n"+
"Daemon version: %s\n"+ "Daemon version: %s\n"+
"CLI version: %s\n"+ "CLI version: %s\n"+
"Management: %s\n"+ "Management: %s\n"+
@@ -557,7 +538,6 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Quantum resistance: %s\n"+ "Quantum resistance: %s\n"+
"Routes: %s\n"+ "Routes: %s\n"+
"Peers count: %s\n", "Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion, overview.DaemonVersion,
version.NetbirdVersion(), version.NetbirdVersion(),
managementConnString, managementConnString,
@@ -613,6 +593,15 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
if peerState.IceCandidateEndpoint.Remote != "" { if peerState.IceCandidateEndpoint.Remote != "" {
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
} }
lastStatusUpdate := "-"
if !peerState.LastStatusUpdate.IsZero() {
lastStatusUpdate = peerState.LastStatusUpdate.Format("2006-01-02 15:04:05")
}
lastWireGuardHandshake := "-"
if !peerState.LastWireguardHandshake.IsZero() && peerState.LastWireguardHandshake != time.Unix(0, 0) {
lastWireGuardHandshake = peerState.LastWireguardHandshake.Format("2006-01-02 15:04:05")
}
rosenpassEnabledStatus := "false" rosenpassEnabledStatus := "false"
if rosenpassEnabled { if rosenpassEnabled {
@@ -644,9 +633,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Status: %s\n"+ " Status: %s\n"+
" -- detail --\n"+ " -- detail --\n"+
" Connection type: %s\n"+ " Connection type: %s\n"+
" Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+ " ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+ " ICE candidate endpoints (Local/Remote): %s/%s\n"+
" Relay server address: %s\n"+
" Last connection update: %s\n"+ " Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+ " Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+ " Transfer status (received/sent) %s/%s\n"+
@@ -658,13 +647,13 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
peerState.PubKey, peerState.PubKey,
peerState.Status, peerState.Status,
peerState.ConnType, peerState.ConnType,
peerState.Direct,
localICE, localICE,
remoteICE, remoteICE,
localICEEndpoint, localICEEndpoint,
remoteICEEndpoint, remoteICEEndpoint,
peerState.RelayAddress, lastStatusUpdate,
timeAgo(peerState.LastStatusUpdate), lastWireGuardHandshake,
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived), toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent), toIEC(peerState.TransferSent),
rosenpassEnabledStatus, rosenpassEnabledStatus,
@@ -733,124 +722,3 @@ func countEnabled(dnsServers []nsServerGroupStateOutput) int {
} }
return count return count
} }
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
func timeAgo(t time.Time) string {
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
return "-"
}
duration := time.Since(t)
switch {
case duration < time.Second:
return "Now"
case duration < time.Minute:
seconds := int(duration.Seconds())
if seconds == 1 {
return "1 second ago"
}
return fmt.Sprintf("%d seconds ago", seconds)
case duration < time.Hour:
minutes := int(duration.Minutes())
seconds := int(duration.Seconds()) % 60
if minutes == 1 {
if seconds == 1 {
return "1 minute, 1 second ago"
} else if seconds > 0 {
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
}
return "1 minute ago"
}
if seconds > 0 {
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
}
return fmt.Sprintf("%d minutes ago", minutes)
case duration < 24*time.Hour:
hours := int(duration.Hours())
minutes := int(duration.Minutes()) % 60
if hours == 1 {
if minutes == 1 {
return "1 hour, 1 minute ago"
} else if minutes > 0 {
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
}
return "1 hour ago"
}
if minutes > 0 {
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
}
return fmt.Sprintf("%d hours ago", hours)
}
days := int(duration.Hours()) / 24
hours := int(duration.Hours()) % 24
if days == 1 {
if hours == 1 {
return "1 day, 1 hour ago"
} else if hours > 0 {
return fmt.Sprintf("1 day, %d hours ago", hours)
}
return "1 day ago"
}
if hours > 0 {
return fmt.Sprintf("%d days, %d hours ago", days, hours)
}
return fmt.Sprintf("%d days ago", days)
}
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
}
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeRoute(route)
}
}
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
for i, peer := range overview.Peers.Details {
peer := peer
anonymizePeerDetail(a, &peer)
overview.Peers.Details[i] = peer
}
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
overview.IP = a.AnonymizeIPString(overview.IP)
for i, detail := range overview.Relays.Details {
detail.URI = a.AnonymizeURI(detail.URI)
detail.Error = a.AnonymizeString(detail.Error)
overview.Relays.Details[i] = detail
}
for i, nsGroup := range overview.NSServerGroups {
for j, domain := range nsGroup.Domains {
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
}
for j, ns := range nsGroup.Servers {
host, port, err := net.SplitHostPort(ns)
if err == nil {
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
}
}
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
}

View File

@@ -3,8 +3,6 @@ package cmd
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt"
"runtime"
"testing" "testing"
"time" "time"
@@ -37,6 +35,7 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected", ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)), ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false, Relayed: false,
Direct: true,
LocalIceCandidateType: "", LocalIceCandidateType: "",
RemoteIceCandidateType: "", RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "", LocalIceCandidateEndpoint: "",
@@ -56,6 +55,7 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected", ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)), ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true, Relayed: true,
Direct: false,
LocalIceCandidateType: "relay", LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx", RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001", LocalIceCandidateEndpoint: "10.0.0.1:10001",
@@ -135,6 +135,7 @@ var overview = statusOutputOverview{
Status: "Connected", Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC), LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P", ConnType: "P2P",
Direct: true,
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: "", Local: "",
Remote: "", Remote: "",
@@ -158,6 +159,7 @@ var overview = statusOutputOverview{
Status: "Connected", Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC), LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed", ConnType: "Relayed",
Direct: false,
IceCandidateType: iceCandidateType{ IceCandidateType: iceCandidateType{
Local: "relay", Local: "relay",
Remote: "prflx", Remote: "prflx",
@@ -279,6 +281,7 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected", "status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z", "lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P", "connectionType": "P2P",
"direct": true,
"iceCandidateType": { "iceCandidateType": {
"local": "", "local": "",
"remote": "" "remote": ""
@@ -287,7 +290,6 @@ func TestParsingToJSON(t *testing.T) {
"local": "", "local": "",
"remote": "" "remote": ""
}, },
"relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z", "lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200, "transferReceived": 200,
"transferSent": 100, "transferSent": 100,
@@ -304,6 +306,7 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected", "status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z", "lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed", "connectionType": "Relayed",
"direct": false,
"iceCandidateType": { "iceCandidateType": {
"local": "relay", "local": "relay",
"remote": "prflx" "remote": "prflx"
@@ -312,7 +315,6 @@ func TestParsingToJSON(t *testing.T) {
"local": "10.0.0.1:10001", "local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002" "remote": "10.0.10.1:10002"
}, },
"relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z", "lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000, "transferReceived": 2000,
"transferSent": 1000, "transferSent": 1000,
@@ -404,13 +406,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P connectionType: P2P
direct: true
iceCandidateType: iceCandidateType:
local: "" local: ""
remote: "" remote: ""
iceCandidateEndpoint: iceCandidateEndpoint:
local: "" local: ""
remote: "" remote: ""
relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200 transferReceived: 200
transferSent: 100 transferSent: 100
@@ -424,13 +426,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed connectionType: Relayed
direct: false
iceCandidateType: iceCandidateType:
local: relay local: relay
remote: prflx remote: prflx
iceCandidateEndpoint: iceCandidateEndpoint:
local: 10.0.0.1:10001 local: 10.0.0.1:10001
remote: 10.0.10.1:10002 remote: 10.0.10.1:10002
relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000 transferReceived: 2000
transferSent: 1000 transferSent: 1000
@@ -485,15 +487,9 @@ dnsServers:
} }
func TestParsingToDetail(t *testing.T) { func TestParsingToDetail(t *testing.T) {
// Calculate time ago based on the fixture dates
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
detail := parseToFullDetailSummary(overview) detail := parseToFullDetailSummary(overview)
expectedDetail := fmt.Sprintf( expectedDetail :=
`Peers detail: `Peers detail:
peer-1.awesome-domain.com: peer-1.awesome-domain.com:
NetBird IP: 192.168.178.101 NetBird IP: 192.168.178.101
@@ -501,11 +497,11 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected Status: Connected
-- detail -- -- detail --
Connection type: P2P Connection type: P2P
Direct: true
ICE candidate (Local/Remote): -/- ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/- ICE candidate endpoints (Local/Remote): -/-
Relay server address: Last connection update: 2001-01-01 01:01:01
Last connection update: %s Last WireGuard handshake: 2001-01-01 01:01:02
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B Transfer status (received/sent) 200 B/100 B
Quantum resistance: false Quantum resistance: false
Routes: 10.1.0.0/24 Routes: 10.1.0.0/24
@@ -517,19 +513,18 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected Status: Connected
-- detail -- -- detail --
Connection type: Relayed Connection type: Relayed
Direct: false
ICE candidate (Local/Remote): relay/prflx ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
Relay server address: Last connection update: 2002-02-02 02:02:02
Last connection update: %s Last WireGuard handshake: 2002-02-02 02:02:03
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false Quantum resistance: false
Routes: - Routes: -
Latency: 10ms Latency: 10ms
OS: %s/%s
Daemon version: 0.14.1 Daemon version: 0.14.1
CLI version: %s CLI version: development
Management: Connected to my-awesome-management.com:443 Management: Connected to my-awesome-management.com:443
Signal: Connected to my-awesome-signal.com:443 Signal: Connected to my-awesome-signal.com:443
Relays: Relays:
@@ -544,7 +539,7 @@ Interface type: Kernel
Quantum resistance: false Quantum resistance: false
Routes: 10.10.0.0/24 Routes: 10.10.0.0/24
Peers count: 2/2 Connected Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) `
assert.Equal(t, expectedDetail, detail) assert.Equal(t, expectedDetail, detail)
} }
@@ -552,8 +547,8 @@ Peers count: 2/2 Connected
func TestParsingToShortVersion(t *testing.T) { func TestParsingToShortVersion(t *testing.T) {
shortVersion := parseGeneralSummary(overview, false, false, false) shortVersion := parseGeneralSummary(overview, false, false, false)
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + ` expectedString :=
Daemon version: 0.14.1 `Daemon version: 0.14.1
CLI version: development CLI version: development
Management: Connected Management: Connected
Signal: Connected Signal: Connected
@@ -577,31 +572,3 @@ func TestParsingOfIP(t *testing.T) {
assert.Equal(t, "192.168.178.123\n", parsedIP) assert.Equal(t, "192.168.178.123\n", parsedIP)
} }
func TestTimeAgo(t *testing.T) {
now := time.Now()
cases := []struct {
name string
input time.Time
expected string
}{
{"Now", now, "Now"},
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
{"Zero time", time.Time{}, "-"},
{"Unix zero time", time.Unix(0, 0), "-"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
result := timeAgo(tc.input)
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
})
}
}

View File

@@ -3,21 +3,17 @@ package cmd
import ( import (
"context" "context"
"net" "net"
"path/filepath"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/management-integrations/integrations"
clientProto "github.com/netbirdio/netbird/client/proto" clientProto "github.com/netbirdio/netbird/client/proto"
client "github.com/netbirdio/netbird/client/server" client "github.com/netbirdio/netbird/client/server"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
@@ -33,12 +29,18 @@ func startTestingServices(t *testing.T) string {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testDir := t.TempDir()
config.Datadir = testDir
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
if err != nil {
t.Fatal(err)
}
_, signalLis := startSignal(t) _, signalLis := startSignal(t)
signalAddr := signalLis.Addr().String() signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config, "../testdata/store.sql") _, mgmLis := startManagement(t, config)
mgmAddr := mgmLis.Addr().String() mgmAddr := mgmLis.Addr().String()
return mgmAddr return mgmAddr
} }
@@ -50,10 +52,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
srv, err := sig.NewServer(context.Background(), otel.Meter("")) sigProto.RegisterSignalExchangeServer(s, sig.NewServer())
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
go func() { go func() {
if err := s.Serve(lis); err != nil { if err := s.Serve(lis); err != nil {
panic(err) panic(err)
@@ -63,37 +62,30 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) store, err := mgmt.NewStoreFromJson(config.Datadir, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(cleanUp)
peersUpdateManager := mgmt.NewPeersUpdateManager(nil) peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{} eventStore := &activity.InMemoryEventStore{}
if err != nil { if err != nil {
return nil, nil return nil, nil
} }
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore) iv, _ := integrations.NewIntegratedValidator(eventStore)
accountManager, err := mgmt.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) mgmtServer, err := mgmt.NewServer(config, accountManager, peersUpdateManager, turnManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -108,7 +100,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
} }
func startClientDaemon( func startClientDaemon(
t *testing.T, ctx context.Context, _, configPath string, t *testing.T, ctx context.Context, managementURL, configPath string,
) (*grpc.Server, net.Listener) { ) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", "127.0.0.1:0") lis, err := net.Listen("tcp", "127.0.0.1:0")

View File

@@ -7,19 +7,17 @@ import (
"net/netip" "net/netip"
"runtime" "runtime"
"strings" "strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -42,12 +40,7 @@ func init() {
upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground") upCmd.PersistentFlags().BoolVarP(&foregroundMode, "foreground-mode", "F", false, "start service in foreground")
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -123,10 +116,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.WireguardPort = &p ic.WireguardPort = &p
} }
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey ic.PreSharedKey = &preSharedKey
} }
@@ -143,15 +132,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
} }
} }
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
@@ -159,7 +139,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath) config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, setupKey)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -167,12 +147,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx) ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
return internal.RunClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
} }
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -207,13 +182,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
return nil return nil
} }
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: setupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
@@ -256,14 +226,6 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.WireguardPort = &wp loginRequest.WireguardPort = &wp
} }
if cmd.Flag(networkMonitorFlag).Changed {
loginRequest.NetworkMonitor = &networkMonitor
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

View File

@@ -2,7 +2,6 @@ package cmd
import ( import (
"context" "context"
"os"
"testing" "testing"
"time" "time"
@@ -41,36 +40,6 @@ func TestUpDaemon(t *testing.T) {
return return
} }
// Test the setup-key-file flag.
tempFile, err := os.CreateTemp("", "setup-key")
if err != nil {
t.Errorf("could not create temp file, got error %v", err)
return
}
defer os.Remove(tempFile.Name())
if _, err := tempFile.Write([]byte("A2C8E62B-38F5-4553-B31E-DD66C696CEBB")); err != nil {
t.Errorf("could not write to temp file, got error %v", err)
return
}
if err := tempFile.Close(); err != nil {
t.Errorf("unable to close file, got error %v", err)
}
rootCmd.SetArgs([]string{
"login",
"--daemon-addr", "tcp://" + cliAddr,
"--setup-key-file", tempFile.Name(),
"--log-file", "",
})
if err := rootCmd.Execute(); err != nil {
t.Errorf("expected no error while running up command, got %v", err)
return
}
time.Sleep(time.Second * 3)
if status, err := state.Status(); err != nil && status != internal.StatusIdle {
t.Errorf("wrong status after login: %s, %v", internal.StatusIdle, err)
return
}
rootCmd.SetArgs([]string{ rootCmd.SetArgs([]string{
"up", "up",
"--daemon-addr", "tcp://" + cliAddr, "--daemon-addr", "tcp://" + cliAddr,

View File

@@ -1,30 +0,0 @@
package errors
import (
"fmt"
"strings"
"github.com/hashicorp/go-multierror"
)
func formatError(es []error) string {
if len(es) == 1 {
return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
for i, err := range es {
points[i] = fmt.Sprintf("* %s", err)
}
return fmt.Sprintf(
"%d errors occurred:\n\t%s",
len(es), strings.Join(points, "\n\t"))
}
func FormatErrorOrNil(err *multierror.Error) error {
if err != nil {
err.ErrorFormat = formatError
}
return err.ErrorOrNil()
}

View File

@@ -42,20 +42,20 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Debug("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface) fm, errFw = nbiptables.Create(context, iface)
if errFw != nil { if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw) log.Errorf("failed to create iptables manager: %s", errFw)
} }
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Debug("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface) fm, errFw = nbnftables.Create(context, iface)
if errFw != nil { if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw) log.Errorf("failed to create nftables manager: %s", errFw)
} }
default: default:
errFw = fmt.Errorf("no firewall manager found") errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Debug("no firewall manager found, try to use userspace packet filtering firewall")
} }
if iface.IsUserspaceBind() { if iface.IsUserspaceBind() {
@@ -85,58 +85,16 @@ func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager,
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() FWType { func check() FWType {
useIPTABLES := false
var iptablesChains []string
ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err == nil && isIptablesClientAvailable(ip) {
major, minor, _ := ip.GetIptablesVersion()
// use iptables when its version is lower than 1.8.0 which doesn't work well with our nftables manager
if major < 1 || (major == 1 && minor < 8) {
return IPTABLES
}
useIPTABLES = true
iptablesChains, err = ip.ListChains("filter")
if err != nil {
log.Errorf("failed to list iptables chains: %s", err)
useIPTABLES = false
}
}
nf := nftables.Conn{} nf := nftables.Conn{}
if chains, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" { if _, err := nf.ListChains(); err == nil && os.Getenv(SKIP_NFTABLES_ENV) != "true" {
if !useIPTABLES {
return NFTABLES return NFTABLES
} }
// search for chains where table is filter ip, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
// if we find one, we assume that nftables manager can be used with iptables if err != nil {
for _, chain := range chains { return UNKNOWN
if chain.Table.Name == "filter" {
return NFTABLES
} }
} if isIptablesClientAvailable(ip) {
// check tables for the following constraints:
// 1. there is no chain in nftables for the filter table and there is at least one chain in iptables, we assume that nftables manager can not be used
// 2. there is no tables or more than one table, we assume that nftables manager can be used
// 3. there is only one table and its name is filter, we assume that nftables manager can not be used, since there was no chain in it
// 4. if we find an error we log and continue with iptables check
nbTablesList, err := nf.ListTables()
switch {
case err == nil && len(iptablesChains) > 0:
return IPTABLES
case err == nil && len(nbTablesList) != 1:
return NFTABLES
case err == nil && len(nbTablesList) == 1 && nbTablesList[0].Name == "filter":
return IPTABLES
case err != nil:
log.Errorf("failed to list nftables tables on fw manager discovery: %s", err)
}
}
if useIPTABLES {
return IPTABLES return IPTABLES
} }

View File

@@ -1,13 +1,11 @@
package firewall package firewall
import ( import "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() iface.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(iface.PacketFilter) error
} }

View File

@@ -11,7 +11,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -20,31 +19,26 @@ const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
)
type entry struct { postRoutingMark = "0x000007e4"
spec []string )
position int
}
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routeingFwChainName string
entries map[string][][]string entries map[string][][]string
optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName, routeingFwChainName: routeingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
} }
@@ -54,7 +48,6 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
} }
m.seedInitialEntries() m.seedInitialEntries()
m.seedInitialOptionalEntries()
err = m.cleanChains() err = m.cleanChains()
if err != nil { if err != nil {
@@ -68,7 +61,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
return m, nil return m, nil
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -134,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
return nil, err return nil, err
} }
@@ -146,16 +139,28 @@ func (m *aclManager) AddPeerFiltering(
chain: chain, chain: chain,
} }
if !shouldAddToPrerouting(protocol, dPort, direction) {
return []firewall.Rule{rule}, nil return []firewall.Rule{rule}, nil
}
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
} }
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { func (m *aclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset // delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok { if _, ok := ipsetList.ips[r.ip]; ok {
@@ -180,7 +185,14 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) DELETERULE:
var table string
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
}
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
if err != nil { if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
} }
@@ -191,6 +203,44 @@ func (m *aclManager) Reset() error {
return m.cleanChains() return m.cleanChains()
} }
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
}
specs = append(src, specs...)
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
}
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error { func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
@@ -243,7 +293,8 @@ func (m *aclManager) cleanChains() error {
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil { if err != nil {
return fmt.Errorf("list chains: %w", err) log.Debugf("failed to list chains: %s", err)
return err
} }
if ok { if ok {
for _, rule := range m.entries["PREROUTING"] { for _, rule := range m.entries["PREROUTING"] {
@@ -252,6 +303,11 @@ func (m *aclManager) cleanChains() error {
log.Errorf("failed to delete rule: %v, %s", rule, err) log.Errorf("failed to delete rule: %v, %s", rule, err)
} }
} }
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
} }
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
@@ -282,66 +338,58 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { if chainName == "FORWARD" {
// position 2 because we add it after router's, jump rule
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err) log.Debugf("failed to create input chain jump rule: %s", err)
return err return err
} }
} }
} }
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
} }
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil return nil
} }
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
established := getConntrackEstablished() m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD",
} []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
func (m *aclManager) seedInitialOptionalEntries() { m.appendToEntries("PREROUTING",
m.optionalEntries["FORWARD"] = []entry{ []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
@@ -408,3 +456,18 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName return ipsetName
} }
} }
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -4,14 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
// Manager of iptables firewall // Manager of iptables firewall
@@ -22,7 +21,7 @@ type Manager struct {
ipv4Client *iptables.IPTables ipv4Client *iptables.IPTables
aclMgr *aclManager aclMgr *aclManager
router *router router *routerManager
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -44,12 +43,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(context, iptablesClient, wgIface) m.router, err = newRouterManager(context, iptablesClient)
if err != nil { if err != nil {
log.Debugf("failed to initialize route related chains: %s", err) log.Debugf("failed to initialize route related chains: %s", err)
return nil, err return nil, err
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
if err != nil { if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err) log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err return nil, err
@@ -58,10 +57,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
// AddPeerFiltering adds a rule to the firewall // AddFiltering rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -74,62 +73,33 @@ func (m *Manager) AddPeerFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( // DeleteRule from the firewall by rule definition
sources []netip.Prefix, func (m *Manager) DeleteRule(rule firewall.Rule) error {
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { return m.aclMgr.DeleteRule(rule)
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) return m.router.InsertRoutingRules(pair)
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) return m.router.RemoveRoutingRules(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Reset firewall to the default state
@@ -155,7 +125,7 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
_, err := m.AddPeerFiltering( _, err := m.AddFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -168,7 +138,7 @@ func (m *Manager) AllowNetbird() error {
if err != nil { if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err) return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
} }
_, err = m.AddPeerFiltering( _, err = m.AddFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -183,7 +153,3 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -11,24 +11,9 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -55,8 +40,23 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), ifaceMock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule1 { for _, r := range rule1 {
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset()
@@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering( rule1, err = manager.AddFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT, ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic", fw.ActionAccept, "default", "accept HTTP traffic",
) )
@@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []int{443},
} }
rule2, err = manager.AddPeerFiltering( rule2, err = manager.AddFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range", "default", "accept HTTPS traffic from ports range",
) )
@@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeletePeerRule(r) err := manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -5,22 +5,16 @@ package iptables
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"strconv"
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
) )
const ( const (
Ipv4Forwarding = "netbird-rt-forwarding"
ipv4Nat = "netbird-rt-nat" ipv4Nat = "netbird-rt-nat"
) )
@@ -28,456 +22,319 @@ const (
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
matchSet = "--match-set"
) )
type routeFilteringRuleParams struct { type routerManager struct {
Sources []netip.Prefix
Destination netip.Prefix
Proto firewall.Protocol
SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
}
type router struct {
ctx context.Context ctx context.Context
stop context.CancelFunc stop context.CancelFunc
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
rules map[string][]string rules map[string][]string
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
wgIface iFaceMapper
legacyManagement bool
} }
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
ctx, cancel := context.WithCancel(parentCtx) ctx, cancel := context.WithCancel(parentCtx)
r := &router{ m := &routerManager{
ctx: ctx, ctx: ctx,
stop: cancel, stop: cancel,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface,
} }
r.ipsetCounter = refcounter.New( err := m.cleanUpDefaultForwardRules()
r.createIpSet,
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
}
err := r.cleanUpDefaultForwardRules()
if err != nil { if err != nil {
log.Errorf("cleanup routing rules: %s", err) log.Errorf("failed to cleanup routing rules: %s", err)
return nil, err return nil, err
} }
err = r.createContainers() err = m.createContainers()
if err != nil { if err != nil {
log.Errorf("create containers for route: %s", err) log.Errorf("failed to create containers for route: %s", err)
} }
return r, err return m, err
} }
func (r *router) AddRouteFiltering( // InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
sources []netip.Prefix, func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
destination netip.Prefix, err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
proto firewall.Protocol, if err != nil {
sPort *firewall.Port, return err
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
} }
var setName string err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
if len(sources) > 1 { if err != nil {
setName = firewall.GenerateSetName(sources) return err
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}
params := routeFilteringRuleParams{
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
r.rules[string(ruleKey)] = rule
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
}
return nil
}
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
}
}
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return struct{}{}, nil
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
return nil
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
} }
if !pair.Masquerade { if !pair.Masquerade {
return nil return nil
} }
if err := r.addNatRule(pair); err != nil { err = i.insertRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
return fmt.Errorf("add nat rule: %w", err) if err != nil {
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
return nil
}
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err return err
} }
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = i.insertRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { if err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return err
}
r.rules[ruleKey] = rule
return nil
}
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
} }
return nil return nil
} }
// GetLegacyManagement returns the current legacy management mode // insertRoutingRule inserts an iptable rule
func (r *router) GetLegacyManagement() bool { func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
return r.legacyManagement var err error
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, ruleKey, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
delete(i.rules, ruleKey)
}
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
} }
// SetLegacyManagement sets the route manager to use legacy management mode // RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (r *router) SetLegacyManagement(isLegacy bool) { func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
r.legacyManagement = isLegacy err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
}
return nil
} }
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
func (r *router) RemoveAllLegacyRouteRules() error { var err error
var merr *multierror.Error
for k, rule := range r.rules { ruleKey := firewall.GenKey(keyFormat, pair.ID)
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { existingRule, found := i.rules[ruleKey]
continue if found {
} err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
} }
} }
return nberrors.FormatErrorOrNil(merr) delete(i.rules, ruleKey)
return nil
} }
func (r *router) Reset() error { func (i *routerManager) RouteingFwChainName() string {
var merr *multierror.Error return chainRTFWD
if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) cleanUpDefaultForwardRules() error { func (i *routerManager) Reset() error {
err := r.cleanJumpRules() err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
}
i.rules = make(map[string][]string)
return nil
}
func (i *routerManager) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules()
if err != nil { if err != nil {
return err return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} { ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
table := r.getTableForChain(chain)
ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil { if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err) log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
return err return err
} else if ok { } else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain) err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
if err != nil { if err != nil {
log.Errorf("failed cleaning chain %s, error: %v", chain, err) log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
return err return err
} }
} }
}
return nil ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
}
func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} {
if err := r.createAndSetupChain(chain); err != nil {
return fmt.Errorf("create chain %s: %w", chain, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT {
return tableNat
}
return tableFilter
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert established rule: %v", err) log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
return err
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err
}
}
return nil
}
func (i *routerManager) createContainers() error {
if i.rules[Ipv4Forwarding] != nil {
return nil
} }
ruleKey := "established-" + chain errMSGFormat := "failed creating chain %s,error: %v"
r.rules[ruleKey] = establishedRule err := i.createChain(tableFilter, chainRTFWD)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
}
err = i.createChain(tableNat, chainRTNAT)
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTNAT, err)
}
err = i.addJumpRules()
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil return nil
} }
func (r *router) addJumpRules() error { // addJumpRules create jump rules to send packets to NetBird chains
rule := []string{"-j", chainRTNAT} func (i *routerManager) addJumpRules() error {
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil { if err != nil {
return err return err
} }
r.rules[ipv4Nat] = rule i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4Nat] = rule
return nil return nil
} }
func (r *router) cleanJumpRules() error { // cleanJumpRules cleans jump rules that was sending packets to NetBird chains
rule, found := r.rules[ipv4Nat] func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found { if found {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil { if err != nil {
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
} }
} }
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil return nil
} }
func (r *router) addNatRule(pair firewall.RouterPair) error { func (i *routerManager) createChain(table, newChain string) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) chains, err := i.iptablesClient.ListChains(table)
if err != nil {
if rule, exists := r.rules[ruleKey]; exists { return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} }
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) shouldCreateChain := true
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { for _, chain := range chains {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) if chain == newChain {
shouldCreateChain = false
}
} }
r.rules[ruleKey] = rule if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil return nil
} }
func (r *router) removeNatRule(pair firewall.RouterPair) error { // genRuleSpec generates rule specification with comment identifier
ruleKey := firewall.GenKey(firewall.NatFormat, pair) func genRuleSpec(jump, id, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump, "-m", "comment", "--comment", id}
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("nat rule %s not found", ruleKey)
}
return nil
} }
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { func getIptablesRuleType(table string) string {
intdir := "-i" ruleType := "forwarding"
if inverse { if table == tableNat {
intdir = "-o" ruleType = "nat"
} }
return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} return ruleType
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
}
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
rule = append(rule, "-j", actionToStr(params.Action))
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil
}
if port.IsRange && len(port.Values) == 2 {
return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
}
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(p)
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(port.Values[0])}
} }

View File

@@ -4,13 +4,11 @@ package iptables
import ( import (
"context" "context"
"net/netip"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -30,7 +28,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
defer func() { defer func() {
@@ -39,22 +37,28 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules, 2, "should have created rules map") require.Len(t, manager.rules, 2, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: "100.100.100.0/24",
Masquerade: true, Masquerade: true,
} }
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} forward4RuleKey := firewall.GenKey(firewall.ForwardingFormat, pair.ID)
forward4Rule := genRuleSpec(routingFinalForwardJump, forward4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) nat4RuleKey := firewall.GenKey(firewall.NatFormat, pair.ID)
nat4Rule := genRuleSpec(routingFinalNatJump, nat4RuleKey, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
@@ -63,7 +67,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_AddNatRule(t *testing.T) { func TestIptablesManager_InsertRoutingRules(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
@@ -74,7 +78,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
defer func() { defer func() {
@@ -84,13 +88,35 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
} }
}() }()
err = manager.AddNatRule(testCase.InputPair) err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "forwarding pair should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "forwarding rule should exist")
foundRule, found := manager.rules[forwardRuleKey]
require.True(t, found, "forwarding rule should exist in the manager map")
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created") require.True(t, exists, "nat rule should be created")
@@ -103,8 +129,8 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
require.False(t, foundNat, "nat rule should not exist in the map") require.False(t, foundNat, "nat rule should not exist in the map")
} }
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -122,7 +148,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
} }
} }
func TestIptablesManager_RemoveNatRule(t *testing.T) { func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
@@ -132,7 +158,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouterManager(context.TODO(), iptablesClient)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
defer func() { defer func() {
_ = manager.Reset() _ = manager.Reset()
@@ -140,14 +166,26 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) forwardRule := genRuleSpec(routingFinalForwardJump, forwardRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
require.NoError(t, err, "inserting rule should not return error")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, inForwardRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
require.NoError(t, err, "inserting rule should not return error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, natRuleKey, testCase.InputPair.Source, testCase.InputPair.Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) inNatRule := genRuleSpec(routingFinalNatJump, inNatRuleKey, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "inserting rule should not return error")
@@ -155,14 +193,28 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist") require.False(t, exists, "nat rule should not exist")
_, found := manager.rules[natRuleKey] _, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map") require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
@@ -171,175 +223,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
_, found = manager.rules[inNatRuleKey] _, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map") require.False(t, found, "income nat rule should exist in the manager map")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if !isIptablesSupported() {
t.Skip("iptables not supported on this system")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
}) })
} }
} }

View File

@@ -1,21 +1,15 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
) )
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" NatFormat = "netbird-nat-%s"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s"
NatFormat = "netbird-nat-%s-%t" InNatFormat = "netbird-nat-in-%s"
InForwardingFormat = "netbird-fwd-in-%s"
) )
// Rule abstraction should be implemented by each firewall manager // Rule abstraction should be implemented by each firewall manager
@@ -55,11 +49,11 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
AllowNetbird() error AllowNetbird() error
// AddPeerFiltering adds a rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddFiltering(
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
@@ -70,25 +64,17 @@ type Manager interface {
comment string, comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
DeletePeerRule(rule Rule) error DeleteRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) // InsertRoutingRules inserts a routing firewall rule
InsertRoutingRules(pair RouterPair) error
// DeleteRouteRule deletes a routing rule // RemoveRoutingRules removes a routing firewall rule
DeleteRouteRule(rule Rule) error RemoveRoutingRules(pair RouterPair) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset() error
@@ -97,89 +83,6 @@ type Manager interface {
Flush() error Flush() error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, input string) string {
return fmt.Sprintf(format, pair.ID, pair.Inverse) return fmt.Sprintf(format, input)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
} }

View File

@@ -1,192 +0,0 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -1,26 +1,18 @@
package manager package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct { type RouterPair struct {
ID route.ID ID string
Source netip.Prefix Source string
Destination netip.Prefix Destination string
Masquerade bool Masquerade bool
Inverse bool
} }
func GetInversePair(pair RouterPair) RouterPair { func GetInPair(pair RouterPair) RouterPair {
return RouterPair{ return RouterPair{
ID: pair.ID, ID: pair.ID,
// invert Source/Destination // invert Source/Destination
Source: pair.Destination, Source: pair.Destination,
Destination: pair.Source, Destination: pair.Source,
Masquerade: pair.Masquerade, Masquerade: pair.Masquerade,
Inverse: true,
} }
} }

View File

@@ -11,14 +11,12 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -31,26 +29,26 @@ const (
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter" chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
const flushError = "flush: %w"
var ( var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
) )
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routeingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
@@ -63,10 +61,10 @@ type iFaceMapper interface {
IsUserspaceBind() bool IsUserspaceBind() bool
} }
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for both type of operations // and is permanent. Using same connection for booth type of operations
// overloads netlink with high amount of rules ( > 10000) // overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting()) sConn, err := nftables.New(nftables.AsLasting())
if err != nil { if err != nil {
@@ -78,7 +76,7 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routingFwChainName: routingFwChainName, routeingFwChainName: routeingFwChainName,
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
@@ -92,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
return m, nil return m, nil
} }
// AddPeerFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -122,11 +120,20 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules = append(newRules, ioRule) newRules = append(newRules, ioRule)
if !shouldAddToPrerouting(proto, dPort, direction) {
return newRules, nil
}
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
if err != nil {
return newRules, err
}
newRules = append(newRules, preroutingRule)
return newRules, nil return newRules, nil
} }
// DeletePeerRule from the firewall by rule definition // DeleteRule from the firewall by rule definition
func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { func (m *AclManager) DeleteRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
@@ -192,7 +199,8 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
// createDefaultAllowRules creates default allow rules for the input and output chains // createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
// input and output chains
func (m *AclManager) createDefaultAllowRules() error { func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{ expIn := []expr.Any{
&expr.Payload{ &expr.Payload{
@@ -206,13 +214,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0, 0, 0, 0}, Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: []byte{0, 0, 0, 0}, Xor: zeroXor,
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0x00, 0x00, 0x00, 0x00},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -238,13 +246,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0, 0, 0, 0}, Mask: []byte{0x00, 0x00, 0x00, 0x00},
Xor: []byte{0, 0, 0, 0}, Xor: zeroXor,
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0, 0, 0, 0}, Data: []byte{0x00, 0x00, 0x00, 0x00},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -258,8 +266,10 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut, Exprs: expOut,
}) })
if err := m.rConn.Flush(); err != nil { err := m.rConn.Flush()
return fmt.Errorf(flushError, err) if err != nil {
log.Debugf("failed to create default allow rules: %s", err)
return err
} }
return nil return nil
} }
@@ -280,11 +290,15 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
}
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, r.nftRule,
@@ -294,7 +308,18 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil }, nil
} }
var expressions []expr.Any ifaceKey := expr.MetaKeyIIFNAME
if direction == firewall.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
@@ -304,15 +329,21 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1), Len: uint32(1),
}) })
protoData, err := protoToInt(proto) var protoData []byte
if err != nil { switch proto {
return nil, fmt.Errorf("convert protocol to number: %v", err) case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
} }
expressions = append(expressions, &expr.Cmp{ expressions = append(expressions, &expr.Cmp{
Register: 1, Register: 1,
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Data: []byte{protoData}, Data: protoData,
}) })
} }
@@ -401,9 +432,10 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else { } else {
chain = m.chainOutputRules chain = m.chainOutputRules
} }
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
@@ -421,13 +453,139 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil return rule, nil
} }
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
var protoData []byte
switch proto {
case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
}, nil
}
var ipExpression expr.Any
// add individual IP for match if no ipset defined
rawIP := ip.To4()
if ipset == nil {
ipExpression = &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
}
} else {
ipExpression = &expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
}
}
expressions := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
ipExpression,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
},
}
if port != nil {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*port),
},
)
}
expressions = append(expressions,
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Position: 0,
Exprs: expressions,
UserData: []byte(ruleId),
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err) log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return fmt.Errorf(flushError, err) return err
} }
m.chainInputRules = chain m.chainInputRules = chain
@@ -443,6 +601,9 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME) m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
@@ -454,6 +615,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-output-filter // netbird-acl-output-filter
// type filter hook output priority filter; policy accept; // type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME) m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
@@ -464,106 +626,29 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
// netbird-acl-forward-filter // netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd m.addJumpRulesToRtForward() // to
m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME) m.addMarkAccept()
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return fmt.Errorf(flushError, err) return err
} }
if err := m.allowRedirectedTraffic(chainFwFilter); err != nil { // netbird-acl-output-filter
log.Errorf("failed to allow redirected traffic: %s", err) // type filter hook output priority filter; policy accept;
m.chainPrerouting = m.createPreroutingMangle()
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
return err
} }
return nil return nil
} }
// Makes redirected traffic originally destined for the host itself (now subject to the forward filter) func (m *AclManager) addJumpRulesToRtForward() {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -573,15 +658,68 @@ func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: m.routingFwChainName, Chain: m.routeingFwChainName,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chainFwFilter, Chain: m.chainFwFilter,
Exprs: expressions, Exprs: expressions,
}) })
expressions = []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addMarkAccept() {
// oifname "wt0" meta mark 0x000007e4 accept
// iifname "wt0" meta mark 0x000007e4 accept
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
for _, iface := range ifaces {
expressions := []expr.Any{
&expr.Meta{Key: iface, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: postroutingMark,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
} }
func (m *AclManager) createChain(name string) *nftables.Chain { func (m *AclManager) createChain(name string) *nftables.Chain {
@@ -591,13 +729,10 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
} }
chain = m.rConn.AddChain(chain) chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain return chain
} }
func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain { func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: name, Name: name,
@@ -611,6 +746,74 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.Ch
return m.rConn.AddChain(chain) return m.rConn.AddChain(chain)
} }
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: "netbird-acl-prerouting-filter",
Table: m.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return chain
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
@@ -629,9 +832,101 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil return nil
} }
func (m *AclManager) addJumpRuleToInputChain() {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if netIfName == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
} else {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
}
expressions := []expr.Any{
&expr.Meta{Key: netIfName, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq var srcOp, dstOp expr.CmpOp
if iifname == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
} else {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
}
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1}, &expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -639,6 +934,24 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{ &expr.Payload{
DestRegister: 2, DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -669,6 +982,7 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
} }
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -676,12 +990,47 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: to, Chain: to,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
@@ -783,7 +1132,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil return nil
} }
func generatePeerRuleId( func generateRuleId(
ip net.IP, ip net.IP,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -806,6 +1155,33 @@ func generatePeerRuleId(
} }
return "set:" + ipset.Name + rulesetID return "set:" + ipset.Name + rulesetID
} }
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
// case of icmp port is empty
var p string
if port != nil {
p = port.String()
}
if ipset != nil {
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
} else {
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil && proto != firewall.ProtocolICMP {
return false
}
return true
}
func encodePort(port firewall.Port) []byte { func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
@@ -815,19 +1191,6 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte { func ifname(n string) []byte {
b := make([]byte, 16) b := make([]byte, 16)
copy(b, n+"\x00") copy(b, []byte(n+"\x00"))
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@@ -5,11 +5,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -17,11 +15,8 @@ import (
) )
const ( const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client // tableName is the name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird" tableName = "netbird"
tableNameFilter = "filter"
chainNameInput = "INPUT"
) )
// Manager of iptables firewall // Manager of iptables firewall
@@ -46,12 +41,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return nil, err return nil, err
} }
m.router, err = newRouter(context, workTable, wgIface) m.router, err = newRouter(context, workTable)
if err != nil { if err != nil {
return nil, err return nil, err
} }
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -59,11 +54,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
// AddPeerFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -81,52 +76,33 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
} }
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { // DeleteRule from the firewall by rule definition
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { return m.aclManager.DeleteRule(rule)
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddNatRule(pair) return m.router.InsertRoutingRules(pair)
} }
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveNatRule(pair) return m.router.RemoveRoutingRules(pair)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -150,7 +126,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward { if c.Table.Name == "filter" && c.Name == "INPUT" {
chain = c chain = c
break break
} }
@@ -181,27 +157,6 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
oldLegacy := m.router.legacyManagement
if oldLegacy != isLegacy {
m.router.legacyManagement = isLegacy
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to cleanup the legacy rules
if !isLegacy && oldLegacy {
if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset() error {
m.mutex.Lock() m.mutex.Lock()
@@ -230,16 +185,14 @@ func (m *Manager) Reset() error {
} }
} }
if err := m.router.Reset(); err != nil { m.router.ResetForwardRules()
return fmt.Errorf("reset forward rules: %v", err)
}
tables, err := m.rConn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list of tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
@@ -265,12 +218,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }
@@ -286,7 +239,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Verdict{}, &expr.Verdict{
Kind: expr.VerdictAccept,
},
}, },
UserData: []byte(allowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }
@@ -296,7 +251,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name()) ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules { for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
if len(rule.Exprs) < 4 { if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue continue
@@ -310,33 +265,3 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
} }
return nil return nil
} }
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
conn.InsertRule(rule)
}

View File

@@ -9,30 +9,14 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -56,9 +40,23 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false } func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), ifaceMock) manager, err := Create(context.Background(), mock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@@ -72,7 +70,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering( rule, err := manager.AddFiltering(
ip, ip,
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
@@ -90,34 +88,17 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 2, "expected 2 rules") require.Len(t, rules, 1, "expected 1 rules")
expectedExprs1 := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -153,10 +134,10 @@ func TestNftablesManager(t *testing.T) {
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
for _, r := range rule { for _, r := range rule {
err = manager.DeletePeerRule(r) err = manager.DeleteRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -165,8 +146,7 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
// established rule remains require.Len(t, rules, 0, "expected 0 rules after deletion")
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
@@ -207,9 +187,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -0,0 +1,413 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// InsertRoutingRules inserts a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) InsertRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.insertRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.insertRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.insertRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// insertRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) insertRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -1,813 +0,0 @@
package nftables
import (
"bytes"
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
)
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
legacyManagement bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
// Reset cleans existing nftables default forward rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.cleanUpDefaultForwardRules()
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
return r.conn.Flush()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
r.acceptForwardRules()
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
})
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
}
exprs = append(exprs, &expr.Verdict{Kind: verdict})
rule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
}
rule = r.conn.AddRule(rule)
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
return nil
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
var elements []nftables.SetElement
for _, prefix := range sources {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
}
}
return ""
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
}
delete(r.rules, ruleKey)
log.Debugf("removed route rule %s", ruleKey)
return nil
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME
if pair.Inverse {
dir = expr.MetaKeyOIFNAME
}
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{
&expr.Meta{
Key: dir,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
exprs = append(exprs,
&expr.Counter{}, &expr.Masq{},
)
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
exprs := []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *router) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return
}
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
}
r.conn.InsertRule(iifRule)
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 2,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
}
// RemoveNatRule removes a nftables rule pair from nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: nat rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, 32)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: mask,
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
},
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
offset := uint32(2) // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
},
)
} else {
// Handle single port or multiple ports
for i, p := range port.Values {
if i > 0 {
// Add a bitwise OR operation between port checks
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
})
}
}
return exprs
}

View File

@@ -4,15 +4,11 @@ package nftables
import ( import (
"context" "context"
"encoding/binary"
"net/netip"
"os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -28,50 +24,56 @@ const (
NFTABLES NFTABLES
) )
func TestNftablesManager_AddNatRule(t *testing.T) { func TestNftablesManager_InsertRoutingRules(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable() table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table") if err != nil {
t.Fatal(err)
}
defer deleteWorkTable() defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock) manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router") require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { defer manager.ResetForwardRules()
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair) err = manager.InsertRoutingRules(testCase.InputPair)
require.NoError(t, err, "pair should be inserted") defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
}(manager, testCase.InputPair)
if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression, fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -86,20 +88,27 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
if testCase.InputPair.Masquerade { sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) testingExpression = append(sourceExp, destExp...) //nolint:gocritic
testingExpression := append(sourceExp, destExp...) //nolint:gocritic inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -113,37 +122,45 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
}) })
} }
} }
func TestNftablesManager_RemoveNatRule(t *testing.T) { func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable() table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table") if err != nil {
t.Fatal(err)
}
defer deleteWorkTable() defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock) manager, err := newRouter(context.TODO(), table)
require.NoError(t, err, "failed to create router") require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { defer manager.ResetForwardRules()
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
})
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable, Table: manager.workTable,
@@ -152,11 +169,20 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
UserData: []byte(natRuleKey), UserData: []byte(natRuleKey),
}) })
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable, Table: manager.workTable,
@@ -168,10 +194,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
err = nftablesTestingClient.Flush() err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
err = manager.Reset() manager.ResetForwardRules()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains { for _, chain := range manager.chains {
@@ -179,7 +204,9 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
} }
} }
@@ -188,468 +215,6 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
} }
} }
func TestRouter_AddRouteFiltering(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() {
require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
})
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:")
for i, expr := range rule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify internal rule content
verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
// Check if the rule exists in nftables and verify its content
rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
require.NoError(t, err, "Failed to get rules from nftables")
var nftRule *nftables.Rule
for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule
break
}
}
require.NotNil(t, nftRule, "Rule not found in nftables")
t.Log("Actual nftables rule expressions:")
for i, expr := range nftRule.Exprs {
t.Logf(" [%d] %T: %+v", i, expr, expr)
}
// Verify actual nftables rule content
verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
})
}
}
func TestNftablesCreateIpSet(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock)
require.NoError(t, err, "Failed to create router")
defer func() {
require.NoError(t, r.Reset(), "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IP",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
},
{
name: "Multiple IPs",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("172.16.0.1/32"),
},
},
{
name: "Single Subnet",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
{
name: "Multiple Subnets with Various Prefix Lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("203.0.113.0/26"),
},
},
{
name: "Mix of Single IPs and Subnets in Different Positions",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
netip.MustParsePrefix("203.0.113.0/24"),
},
},
{
name: "Overlapping IPs/Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}
// Add this helper function inside TestNftablesCreateIpSet
printNftSets := func() {
cmd := exec.Command("nft", "list", "sets")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to run 'nft list sets': %v", err)
} else {
t.Logf("Current nft sets:\n%s", output)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()
require.NoError(t, err, "Failed to create IP set")
}
require.NotNil(t, set, "Created set is nil")
// Verify set properties
assert.Equal(t, setName, set.Name, "Set name mismatch")
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
assert.True(t, set.Interval, "Set interval property should be true")
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
// Fetch the created set from nftables
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
require.NotNil(t, fetchedSet, "Fetched set is nil")
// Verify set elements
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
// Count the number of unique prefixes (excluding interval end markers)
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
// Check against expected merged prefixes
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
// Verify each expected prefix is in the set
for _, expected := range tt.expected {
found := false
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
if expected.Contains(ip) {
found = true
break
}
}
}
assert.True(t, found, "Expected prefix %s not found in set", expected)
}
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
t.Logf("Failed to delete set: %v", err)
printNftSets()
}
require.NoError(t, err, "Failed to delete set")
})
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
// Verify sources and destination
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
// Verify ports
if sPort != nil {
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
}
if dPort != nil {
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
}
// Verify action
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
}
func containsSetLookup(exprs []expr.Any) bool {
for _, e := range exprs {
if _, ok := e.(*expr.Lookup); ok {
return true
}
}
return false
}
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
var offset uint32
if isSource {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
var payloadFound, bitwiseFound, cmpFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
payloadFound = true
}
case *expr.Bitwise:
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
bitwiseFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
cmpFound = true
}
}
}
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
var payloadFound, portMatchFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
portMatchFound = true
}
} else {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
portMatchFound = true
break
}
}
}
}
}
if payloadFound && portMatchFound {
return true
}
}
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
if ex.Key == expr.MetaKeyL4PROTO {
metaFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
cmpFound = true
}
}
}
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
}
return false
}
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int { func check() int {
nf := nftables.Conn{} nf := nftables.Conn{}
@@ -685,12 +250,12 @@ func createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }
table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
err = sConn.Flush() err = sConn.Flush()
return table, err return table, err
@@ -708,7 +273,7 @@ func deleteWorkTable() {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableName {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }

View File

@@ -1,10 +1,8 @@
//go:build !android
package test package test
import ( import firewall "github.com/netbirdio/netbird/client/firewall/manager"
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ( var (
InsertRuleTestCases = []struct { InsertRuleTestCases = []struct {
@@ -15,8 +13,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -24,8 +22,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -40,8 +38,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: "100.100.100.1/32",
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: "100.100.200.0/24",
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -64,18 +64,15 @@ func manageFirewallRule(ruleName string, action action, extraArgs ...string) err
if action == addRule { if action == addRule {
args = append(args, extraArgs...) args = append(args, extraArgs...)
} }
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...) cmd := exec.Command("netsh", args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
return cmd.Run() return cmd.Run()
} }
func isWindowsFirewallReachable() bool { func isWindowsFirewallReachable() bool {
args := []string{"advfirewall", "show", "allprofiles", "state"} args := []string{"advfirewall", "show", "allprofiles", "state"}
cmd := exec.Command("netsh", args...)
netshCmd := GetSystem32Command("netsh")
cmd := exec.Command(netshCmd, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output() _, err := cmd.Output()
@@ -90,23 +87,8 @@ func isWindowsFirewallReachable() bool {
func isFirewallRuleActive(ruleName string) bool { func isFirewallRuleActive(ruleName string) bool {
args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName} args := []string{"advfirewall", "firewall", "show", "rule", "name=" + ruleName}
netshCmd := GetSystem32Command("netsh") cmd := exec.Command("netsh", args...)
cmd := exec.Command(netshCmd, args...)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true}
_, err := cmd.Output() _, err := cmd.Output()
return err == nil return err == nil
} }
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@@ -3,7 +3,6 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -12,8 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
const layerTypeAll = 0 const layerTypeAll = 0
@@ -24,7 +22,7 @@ var (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(iface.PacketFilter) error
Address() iface.WGAddress Address() iface.WGAddress
} }
@@ -105,26 +103,26 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.InsertRoutingRules(pair)
} }
// RemoveNatRule removes a routing firewall rule // RemoveRoutingRules removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveRoutingRules(pair)
} }
// AddPeerFiltering rule to the firewall // AddFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -190,22 +188,8 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { // DeleteRule from the firewall by rule definition
if m.nativeFirewall == nil { func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.DeleteRouteRule(rule)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -231,11 +215,6 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return nil return nil
} }
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(_ bool) error {
return nil
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -358,6 +337,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) { if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
return rule.drop, true return rule.drop, true
} }
return rule.drop, true
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.drop, true
} }
@@ -416,7 +396,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeletePeerRule(&rule) return m.DeleteRule(&rule)
} }
} }
} }
@@ -424,7 +404,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeletePeerRule(&rule) return m.DeleteRule(&rule)
} }
} }
} }

View File

@@ -11,16 +11,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(iface.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
if i.SetFilterFunc == nil { if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented") return fmt.Errorf("not implemented")
} }
@@ -36,7 +35,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -50,10 +49,10 @@ func TestManagerCreate(t *testing.T) {
} }
} }
func TestManagerAddPeerFiltering(t *testing.T) { func TestManagerAddFiltering(t *testing.T) {
isSetFilterCalled := false isSetFilterCalled := false
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { SetFilterFunc: func(iface.PacketFilter) error {
isSetFilterCalled = true isSetFilterCalled = true
return nil return nil
}, },
@@ -72,7 +71,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -91,7 +90,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) { func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -107,7 +106,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -120,14 +119,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule { for _, r := range rule {
err = m.DeletePeerRule(r) err = m.DeleteRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -141,7 +140,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
err = m.DeletePeerRule(r) err = m.DeleteRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -237,7 +236,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) { func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -253,7 +252,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -272,7 +271,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -291,7 +290,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -340,7 +339,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) { func TestRemovePacketHook(t *testing.T) {
// creating mock iface // creating mock iface
iface := &IFaceMock{ iface := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
// creating manager instance // creating manager instance
@@ -389,7 +388,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(iface.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
@@ -407,9 +406,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -1,5 +0,0 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@@ -1,9 +0,0 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}

View File

@@ -1,18 +0,0 @@
//go:build !android
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,29 +0,0 @@
package device
import (
"fmt"
"net"
)
// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err
}
return WGAddress{
IP: ip,
Network: network,
}, nil
}
func (addr WGAddress) String() string {
maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
}

View File

@@ -1,141 +0,0 @@
//go:build !ios
package device
import (
"fmt"
"os/exec"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type TunDevice struct {
name string
address WGAddress
port int
key string
mtu int
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
return &TunDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}
func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *TunDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
t.device = nil
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunDevice) DeviceName() string {
return t.name
}
func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
return err
}
// dummy ipv6 so routing works
cmd = exec.Command("ifconfig", t.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
routeCmd := exec.Command("route", "add", "-net", t.address.Network.String(), "-interface", t.name)
if out, err := routeCmd.CombinedOutput(); err != nil {
log.Errorf("adding route command '%v' failed with output: %s", routeCmd.String(), out)
return err
}
return nil
}

View File

@@ -1,163 +0,0 @@
//go:build (linux && !android) || freebsd
package device
import (
"context"
"fmt"
"net"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/sharedsock"
)
type TunKernelDevice struct {
name string
address WGAddress
wgPort int
key string
mtu int
ctx context.Context
ctxCancel context.CancelFunc
transportNet transport.Net
link *wgLink
udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault
filterFn bind.FilterFn
}
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,
ctxCancel: cancel,
name: name,
address: address,
wgPort: wgPort,
key: key,
mtu: mtu,
transportNet: transportNet,
}
}
func (t *TunKernelDevice) Create() (WGConfigurer, error) {
link := newWGLink(t.name)
if err := link.recreate(); err != nil {
return nil, fmt.Errorf("recreate: %w", err)
}
t.link = link
if err := t.assignAddr(); err != nil {
return nil, fmt.Errorf("assign addr: %w", err)
}
// TODO: do a MTU discovery
log.Debugf("setting MTU: %d interface: %s", t.mtu, t.name)
if err := link.setMTU(t.mtu); err != nil {
return nil, fmt.Errorf("set mtu: %w", err)
}
configurer := configurer.NewKernelConfigurer(t.name)
if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil {
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil
}
func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
if t.link == nil {
return nil, fmt.Errorf("device is not ready yet")
}
log.Debugf("bringing up interface: %s", t.name)
if err := t.link.up(); err != nil {
log.Errorf("error bringing up interface: %s", t.name)
return nil, err
}
rawSock, err := sharedsock.Listen(t.wgPort, sharedsock.NewIncomingSTUNFilter())
if err != nil {
return nil, err
}
bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock,
Net: t.transportNet,
FilterFn: t.filterFn,
}
mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx)
t.udpMuxConn = rawSock
t.udpMux = mux
log.Debugf("device is ready to use: %s", t.name)
return t.udpMux, nil
}
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *TunKernelDevice) Close() error {
if t.link == nil {
return nil
}
t.ctxCancel()
var closErr error
if err := t.link.Close(); err != nil {
log.Debugf("failed to close link: %s", err)
closErr = err
}
if t.udpMux != nil {
if err := t.udpMux.Close(); err != nil {
log.Debugf("failed to close udp mux: %s", err)
closErr = err
}
if err := t.udpMuxConn.Close(); err != nil {
log.Debugf("failed to close udp mux connection: %s", err)
closErr = err
}
}
return closErr
}
func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunKernelDevice) DeviceName() string {
return t.name
}
func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil
}
// assignAddr Adds IP address to the tunnel interface
func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}

View File

@@ -1,120 +0,0 @@
//go:build !android
// +build !android
package device
import (
"fmt"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
)
type TunNetstackDevice struct {
name string
address WGAddress
port int
key string
mtu int
listenAddress string
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
key: key,
mtu: mtu,
listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet, filterFn),
}
}
func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)
return t.configurer, nil
}
func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("netstack device is ready to use")
return udpMux, nil
}
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
func (t *TunNetstackDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}

View File

@@ -1,145 +0,0 @@
//go:build (linux && !android) || freebsd
package device
import (
"fmt"
"os"
"runtime"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type USPDevice struct {
name string
address WGAddress
port int
key string
mtu int
iceBind *bind.ICEBind
device *device.Device
filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
configurer WGConfigurer
}
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
return &USPDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn)}
}
func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err)
}
t.filteredDevice = newDeviceFilter(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
err = t.assignAddr()
if err != nil {
t.device.Close()
return nil, fmt.Errorf("error assigning ip: %s", err)
}
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
err := t.device.Up()
if err != nil {
return nil, err
}
udpMux, err := t.iceBind.GetICEMux()
if err != nil {
return nil, err
}
t.udpMux = udpMux
log.Debugf("device is ready to use: %s", t.name)
return udpMux, nil
}
func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
func (t *USPDevice) Close() error {
if t.configurer != nil {
t.configurer.Close()
}
if t.device != nil {
t.device.Close()
}
if t.udpMux != nil {
return t.udpMux.Close()
}
return nil
}
func (t *USPDevice) WgAddress() WGAddress {
return t.address
}
func (t *USPDevice) DeviceName() string {
return t.name
}
func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface
func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(t.address)
}
func checkUser() {
if runtime.GOOS == "freebsd" {
euid := os.Geteuid()
if euid != 0 {
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
}
}
}

View File

@@ -1,20 +0,0 @@
package device
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -1,18 +0,0 @@
package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
// Despite the fact FreeBSD natively support Wireguard (https://github.com/WireGuard/wireguard-freebsd)
// we are currently do not use it, since it is required to add wireguard kernel support to
// - https://github.com/netbirdio/netbird/tree/main/sharedsock
// - https://github.com/mdlayher/socket
// TODO: implement kernel space
return false
}
// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
}

View File

@@ -1,81 +0,0 @@
package device
import (
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
)
type wgLink struct {
name string
link *freebsd.Link
}
func newWGLink(name string) *wgLink {
link := freebsd.NewLink(name)
return &wgLink{
name: name,
link: link,
}
}
// Type returns the interface type
func (l *wgLink) Type() string {
return "wireguard"
}
// Close deletes the link interface
func (l *wgLink) Close() error {
return l.link.Del()
}
func (l *wgLink) recreate() error {
if err := l.link.Recreate(); err != nil {
return fmt.Errorf("recreate: %w", err)
}
return nil
}
func (l *wgLink) setMTU(mtu int) error {
if err := l.link.SetMTU(mtu); err != nil {
return fmt.Errorf("set mtu: %w", err)
}
return nil
}
func (l *wgLink) up() error {
if err := l.link.Up(); err != nil {
return fmt.Errorf("up: %w", err)
}
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
link, err := freebsd.LinkByName(l.name)
if err != nil {
return fmt.Errorf("link by name: %w", err)
}
ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
err = link.AssignAddr(ip, mask)
if err != nil {
return fmt.Errorf("assign addr: %w", err)
}
err = link.Up()
if err != nil {
return fmt.Errorf("up: %w", err)
}
return nil
}

View File

@@ -1,133 +0,0 @@
//go:build linux && !android
package device
import (
"fmt"
"os"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
type wgLink struct {
attrs *netlink.LinkAttrs
}
func newWGLink(name string) *wgLink {
attrs := netlink.NewLinkAttrs()
attrs.Name = name
return &wgLink{
attrs: &attrs,
}
}
// Attrs returns the Wireguard's default attributes
func (l *wgLink) Attrs() *netlink.LinkAttrs {
return l.attrs
}
// Type returns the interface type
func (l *wgLink) Type() string {
return "wireguard"
}
// Close deletes the link interface
func (l *wgLink) Close() error {
return netlink.LinkDel(l)
}
func (l *wgLink) recreate() error {
name := l.attrs.Name
// check if interface exists
link, err := netlink.LinkByName(name)
if err != nil {
switch err.(type) {
case netlink.LinkNotFoundError:
break
default:
return fmt.Errorf("link by name: %w", err)
}
}
// remove if interface exists
if link != nil {
err = netlink.LinkDel(l)
if err != nil {
return err
}
}
log.Debugf("adding device: %s", name)
err = netlink.LinkAdd(l)
if os.IsExist(err) {
log.Infof("interface %s already exists. Will reuse.", name)
} else if err != nil {
return fmt.Errorf("link add: %w", err)
}
return nil
}
func (l *wgLink) setMTU(mtu int) error {
if err := netlink.LinkSetMTU(l, mtu); err != nil {
log.Errorf("error setting MTU on interface: %s", l.attrs.Name)
return fmt.Errorf("link set mtu: %w", err)
}
return nil
}
func (l *wgLink) up() error {
if err := netlink.LinkSetUp(l); err != nil {
log.Errorf("error bringing up interface: %s", l.attrs.Name)
return fmt.Errorf("link setup: %w", err)
}
return nil
}
func (l *wgLink) assignAddr(address WGAddress) error {
//delete existing addresses
list, err := netlink.AddrList(l, 0)
if err != nil {
return fmt.Errorf("list addr: %w", err)
}
if len(list) > 0 {
for _, a := range list {
addr := a
err = netlink.AddrDel(l, &addr)
if err != nil {
return fmt.Errorf("del addr: %w", err)
}
}
}
name := l.attrs.Name
addrStr := address.String()
log.Debugf("adding address %s to interface: %s", addrStr, name)
addr, err := netlink.ParseAddr(addrStr)
if err != nil {
return fmt.Errorf("parse addr: %w", err)
}
err = netlink.AddrAdd(l, addr)
if os.IsExist(err) {
log.Infof("interface %s already has the address: %s", name, addrStr)
} else if err != nil {
return fmt.Errorf("add addr: %w", err)
}
// On linux, the link must be brought up
if err := netlink.LinkSetUp(l); err != nil {
return fmt.Errorf("link setup: %w", err)
}
return nil
}

View File

@@ -1,15 +0,0 @@
package device
import (
"os"
"golang.zx2c4.com/wireguard/device"
)
func wgLogLevel() int {
if os.Getenv("NB_WG_DEBUG") == "true" {
return device.LogLevelVerbose
} else {
return device.LogLevelSilent
}
}

View File

@@ -1,4 +0,0 @@
package device
// CustomWindowsGUIDString is a custom GUID string for the interface
var CustomWindowsGUIDString string

View File

@@ -1,16 +0,0 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,8 +0,0 @@
package freebsd
import "errors"
var (
ErrDoesNotExist = errors.New("does not exist")
ErrNameDoesNotMatch = errors.New("name does not match")
)

View File

@@ -1,108 +0,0 @@
package freebsd
import (
"bufio"
"fmt"
"strconv"
"strings"
)
type iface struct {
Name string
MTU int
Group string
IPAddrs []string
}
func parseError(output []byte) error {
// TODO: implement without allocations
lines := string(output)
if strings.Contains(lines, "does not exist") {
return ErrDoesNotExist
}
return nil
}
func parseIfconfigOutput(output []byte) (*iface, error) {
// TODO: implement without allocations
lines := string(output)
scanner := bufio.NewScanner(strings.NewReader(lines))
var name, mtu, group string
var ips []string
for scanner.Scan() {
line := scanner.Text()
// If line contains ": flags", it's a line with interface information
if strings.Contains(line, ": flags") {
parts := strings.Fields(line)
if len(parts) < 4 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
name = strings.TrimSuffix(parts[0], ":")
if strings.Contains(line, "mtu") {
mtuIndex := 0
for i, part := range parts {
if part == "mtu" {
mtuIndex = i
break
}
}
mtu = parts[mtuIndex+1]
}
}
// If line contains "groups:", it's a line with interface group
if strings.Contains(line, "groups:") {
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
group = parts[1]
}
// If line contains "inet ", it's a line with IP address
if strings.Contains(line, "inet ") {
parts := strings.Fields(line)
if len(parts) < 2 {
return nil, fmt.Errorf("failed to parse line: %s", line)
}
ips = append(ips, parts[1])
}
}
if name == "" {
return nil, fmt.Errorf("interface name not found in ifconfig output")
}
mtuInt, err := strconv.Atoi(mtu)
if err != nil {
return nil, fmt.Errorf("failed to parse MTU: %w", err)
}
return &iface{
Name: name,
MTU: mtuInt,
Group: group,
IPAddrs: ips,
}, nil
}
func parseIFName(output []byte) (string, error) {
// TODO: implement without allocations
lines := strings.Split(string(output), "\n")
if len(lines) == 0 || lines[0] == "" {
return "", fmt.Errorf("no output returned")
}
fields := strings.Fields(lines[0])
if len(fields) > 1 {
return "", fmt.Errorf("invalid output")
}
return fields[0], nil
}

View File

@@ -1,76 +0,0 @@
package freebsd
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseIfconfigOutput(t *testing.T) {
testOutput := `wg1: flags=8080<NOARP,MULTICAST> metric 0 mtu 1420
options=80000<LINKSTATE>
groups: wg
nd6 options=109<PERFORMNUD,IFDISABLED,NO_DAD>`
expected := &iface{
Name: "wg1",
MTU: 1420,
Group: "wg",
}
result, err := parseIfconfigOutput(([]byte)(testOutput))
if err != nil {
t.Errorf("Error parsing ifconfig output: %v", err)
return
}
assert.Equal(t, expected.Name, result.Name, "Name should match")
assert.Equal(t, expected.MTU, result.MTU, "MTU should match")
assert.Equal(t, expected.Group, result.Group, "Group should match")
}
func TestParseIFName(t *testing.T) {
tests := []struct {
name string
output string
expected string
expectedErr error
}{
{
name: "ValidOutput",
output: "eth0\n",
expected: "eth0",
},
{
name: "ValidOutputOneLine",
output: "eth0",
expected: "eth0",
},
{
name: "EmptyOutput",
output: "",
expectedErr: fmt.Errorf("no output returned"),
},
{
name: "InvalidOutput",
output: "This is an invalid output\n",
expectedErr: fmt.Errorf("invalid output"),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result, err := parseIFName(([]byte)(test.output))
assert.Equal(t, test.expected, result, "Interface names should match")
if test.expectedErr != nil {
assert.NotNil(t, err, "Error should not be nil")
assert.EqualError(t, err, test.expectedErr.Error(), "Error messages should match")
} else {
assert.Nil(t, err, "Error should be nil")
}
})
}
}

View File

@@ -1,239 +0,0 @@
package freebsd
import (
"bytes"
"errors"
"fmt"
"os/exec"
"strconv"
log "github.com/sirupsen/logrus"
)
const wgIFGroup = "wg"
// Link represents a network interface.
type Link struct {
name string
}
func NewLink(name string) *Link {
return &Link{
name: name,
}
}
// LinkByName retrieves a network interface by its name.
func LinkByName(name string) (*Link, error) {
out, err := exec.Command("ifconfig", name).CombinedOutput()
if err != nil {
if pErr := parseError(out); pErr != nil {
return nil, pErr
}
log.Debugf("ifconfig out: %s", out)
return nil, fmt.Errorf("command run: %w", err)
}
i, err := parseIfconfigOutput(out)
if err != nil {
return nil, fmt.Errorf("parse ifconfig output: %w", err)
}
if i.Name != name {
return nil, ErrNameDoesNotMatch
}
return &Link{name: i.Name}, nil
}
// Recreate - create new interface, remove current before create if it exists
func (l *Link) Recreate() error {
ok, err := l.isExist()
if err != nil {
return fmt.Errorf("is exist: %w", err)
}
if ok {
if err := l.del(l.name); err != nil {
return fmt.Errorf("del: %w", err)
}
}
return l.Add()
}
// Add creates a new network interface.
func (l *Link) Add() error {
parsedName, err := l.create(wgIFGroup)
if err != nil {
return fmt.Errorf("create link: %w", err)
}
if parsedName == l.name {
return nil
}
parsedName, err = l.rename(parsedName, l.name)
if err != nil {
errDel := l.del(parsedName)
if errDel != nil {
return fmt.Errorf("del on rename link: %w: %w", err, errDel)
}
return fmt.Errorf("rename link: %w", err)
}
return nil
}
// Del removes an existing network interface.
func (l *Link) Del() error {
return l.del(l.name)
}
// SetMTU sets the MTU of the network interface.
func (l *Link) SetMTU(mtu int) error {
return l.setMTU(mtu)
}
// AssignAddr assigns an IP address and netmask to the network interface.
func (l *Link) AssignAddr(ip, netmask string) error {
return l.setAddr(ip, netmask)
}
func (l *Link) Up() error {
return l.up(l.name)
}
func (l *Link) Down() error {
return l.down(l.name)
}
func (l *Link) isExist() (bool, error) {
_, err := LinkByName(l.name)
if errors.Is(err, ErrDoesNotExist) {
return false, nil
}
if err != nil {
return false, fmt.Errorf("link by name: %w", err)
}
return true, nil
}
func (l *Link) create(groupName string) (string, error) {
cmd := exec.Command("ifconfig", groupName, "create")
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("ifconfig out: %s", output)
return "", fmt.Errorf("create %s interface: %w", groupName, err)
}
interfaceName, err := parseIFName(output)
if err != nil {
return "", fmt.Errorf("parse interface name: %w", err)
}
return interfaceName, nil
}
func (l *Link) rename(oldName, newName string) (string, error) {
cmd := exec.Command("ifconfig", oldName, "name", newName)
output, err := cmd.CombinedOutput()
if err != nil {
log.Debugf("ifconfig out: %s", output)
return "", fmt.Errorf("change name %q -> %q: %w", oldName, newName, err)
}
interfaceName, err := parseIFName(output)
if err != nil {
return "", fmt.Errorf("parse new name: %w", err)
}
return interfaceName, nil
}
func (l *Link) del(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "destroy")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("destroy %s interface: %w", name, err)
}
return nil
}
func (l *Link) setMTU(mtu int) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", l.name, "mtu", strconv.Itoa(mtu))
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("set interface mtu: %w", err)
}
return nil
}
func (l *Link) setAddr(ip, netmask string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", l.name, "inet", ip, "netmask", netmask)
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("set interface addr: %w", err)
}
return nil
}
func (l *Link) up(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "up")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("up %s interface: %w", name, err)
}
return nil
}
func (l *Link) down(name string) error {
var stderr bytes.Buffer
cmd := exec.Command("ifconfig", name, "down")
cmd.Stderr = &stderr
err := cmd.Run()
if err != nil {
log.Debugf("ifconfig out: %s", stderr.String())
return fmt.Errorf("down %s interface: %w", name, err)
}
return nil
}

View File

@@ -1,67 +0,0 @@
//go:build !ios
package iface
import (
"fmt"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
// this function is different on Android
func (w *WGIface) Create() error {
w.mu.Lock()
defer w.mu.Unlock()
backOff := &backoff.ExponentialBackOff{
InitialInterval: 20 * time.Millisecond,
MaxElapsedTime: 500 * time.Millisecond,
Stop: backoff.Stop,
Clock: backoff.SystemClock,
}
operation := func() error {
cfgr, err := w.tun.Create()
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
return backoff.Retry(operation, backOff)
}

View File

@@ -1,17 +0,0 @@
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
package iface
import (
"fmt"
"os/exec"
)
func (w *WGIface) Destroy() error {
out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}

View File

@@ -1,22 +0,0 @@
//go:build linux && !android
package iface
import (
"fmt"
"github.com/vishvananda/netlink"
)
func (w *WGIface) Destroy() error {
link, err := netlink.LinkByName(w.Name())
if err != nil {
return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
}
if err := netlink.LinkDel(link); err != nil {
return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
}
return nil
}

View File

@@ -1,9 +0,0 @@
//go:build android || (ios && !darwin)
package iface
import "errors"
func (w *WGIface) Destroy() error {
return errors.New("not supported on mobile")
}

View File

@@ -1,32 +0,0 @@
//go:build windows
package iface
import (
"fmt"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (w *WGIface) Destroy() error {
netshCmd := GetSystem32Command("netsh")
out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
if err != nil {
return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
}
return nil
}
// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
func GetSystem32Command(command string) string {
_, err := exec.LookPath(command)
if err == nil {
return command
}
log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
return "C:\\windows\\system32\\" + command + ".exe"
}

View File

@@ -1,105 +0,0 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool
NameFunc func() string
AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error)
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc()
}
func (m *MockWGIface) Create() error {
return m.CreateFunc()
}
func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
return m.CreateOnAndroidFunc(routeRange, ip, domains)
}
func (m *MockWGIface) IsUserspaceBind() bool {
return m.IsUserspaceBindFunc()
}
func (m *MockWGIface) Name() string {
return m.NameFunc()
}
func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc()
}
func (m *MockWGIface) ToInterface() *net.Interface {
return m.ToInterfaceFunc()
}
func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
return m.UpFunc()
}
func (m *MockWGIface) UpdateAddr(newAddr string) error {
return m.UpdateAddrFunc(newAddr)
}
func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey)
}
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
return m.AddAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP)
}
func (m *MockWGIface) Close() error {
return m.CloseFunc()
}
func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter)
}
func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc()
}
func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc()
}
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey)
}

View File

@@ -1,49 +0,0 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"runtime"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
}

View File

@@ -1,34 +0,0 @@
//go:build !windows
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -1,33 +0,0 @@
package iface
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
)
type IWGIface interface {
Create() error
CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool
Name() string
Address() device.WGAddress
ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close() error
SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice
GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}

View File

@@ -1,64 +0,0 @@
package id
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"strconv"
"github.com/netbirdio/netbird/client/firewall/manager"
)
type RuleID string
func (r RuleID) GetRuleID() string {
return string(r)
}
func GenerateRouteRuleKey(
sources []netip.Prefix,
destination netip.Prefix,
proto manager.Protocol,
sPort *manager.Port,
dPort *manager.Port,
action manager.Action,
) RuleID {
manager.SortPrefixes(sources)
h := sha256.New()
// Write all fields to the hasher, with delimiters
h.Write([]byte("sources:"))
for _, src := range sources {
h.Write([]byte(src.String()))
h.Write([]byte(","))
}
h.Write([]byte("destination:"))
h.Write([]byte(destination.String()))
h.Write([]byte("proto:"))
h.Write([]byte(proto))
h.Write([]byte("sPort:"))
if sPort != nil {
h.Write([]byte(sPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("dPort:"))
if dPort != nil {
h.Write([]byte(dPort.String()))
} else {
h.Write([]byte("<nil>"))
}
h.Write([]byte("action:"))
h.Write([]byte(strconv.Itoa(int(action))))
hash := hex.EncodeToString(h.Sum(nil))
// prepend destination prefix to be able to identify the rule
return RuleID(fmt.Sprintf("%s-%s", destination.String(), hash[:16]))
}

View File

@@ -5,7 +5,6 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -13,7 +12,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -27,16 +25,14 @@ type Manager interface {
type DefaultManager struct { type DefaultManager struct {
firewall firewall.Manager firewall firewall.Manager
ipsetCounter int ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule rulesPairs map[string][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex mutex sync.Mutex
} }
func NewDefaultManager(fm firewall.Manager) *DefaultManager { func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{ return &DefaultManager{
firewall: fm, firewall: fm,
peerRulesPairs: make(map[id.RuleID][]firewall.Rule), rulesPairs: make(map[string][]firewall.Rule),
routeRules: make(map[id.RuleID]struct{}),
} }
} }
@@ -50,7 +46,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
for _, pairs := range d.peerRulesPairs { for _, pairs := range d.rulesPairs {
total += len(pairs) total += len(pairs)
} }
log.Infof( log.Infof(
@@ -63,34 +59,21 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return return
} }
d.applyPeerACLs(networkMap) defer func() {
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
}
if err := d.firewall.Flush(); err != nil { if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err) log.Error("failed to flush firewall rules: ", err)
} }
} }()
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap) rules, squashedProtocols := d.squashAcceptRules(networkMap)
enableSSH := networkMap.PeerConfig != nil && enableSSH := (networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil && networkMap.PeerConfig.SshConfig != nil &&
networkMap.PeerConfig.SshConfig.SshEnabled networkMap.PeerConfig.SshConfig.SshEnabled)
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
enableSSH = enableSSH && !ok enableSSH = enableSSH && !ok
} }
if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
enableSSH = enableSSH && !ok enableSSH = enableSSH && !ok
} }
@@ -100,9 +83,9 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
if enableSSH { if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{ rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort), Port: strconv.Itoa(ssh.DefaultSSHPort),
}) })
} }
@@ -114,20 +97,20 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules = append(rules, rules = append(rules,
&mgmProto.FirewallRule{ &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
&mgmProto.FirewallRule{ &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
) )
} }
newRulePairs := make(map[id.RuleID][]firewall.Rule) newRulePairs := make(map[string][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string) ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules { for _, r := range rules {
@@ -147,97 +130,29 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
break break
} }
if len(rules) > 0 { if len(rules) > 0 {
d.peerRulesPairs[pairID] = rulePair d.rulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair newRulePairs[pairID] = rulePair
} }
} }
for pairID, rules := range d.peerRulesPairs { for pairID, rules := range d.rulesPairs {
if _, ok := newRulePairs[pairID]; !ok { if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules { for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil { if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete peer firewall rule: %v", err) log.Errorf("failed to delete firewall rule: %v", err)
continue continue
} }
} }
delete(d.peerRulesPairs, pairID) delete(d.rulesPairs, pairID)
} }
} }
d.peerRulesPairs = newRulePairs d.rulesPairs = newRulePairs
}
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
var newRouteRules = make(map[id.RuleID]struct{})
for _, rule := range rules {
id, err := d.applyRouteACL(rule)
if err != nil {
return fmt.Errorf("apply route ACL: %w", err)
}
newRouteRules[id] = struct{}{}
}
for id := range d.routeRules {
if _, ok := newRouteRules[id]; !ok {
if err := d.firewall.DeleteRouteRule(id); err != nil {
log.Errorf("failed to delete route firewall rule: %v", err)
continue
}
delete(d.routeRules, id)
}
}
d.routeRules = newRouteRules
return nil
}
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 {
return "", fmt.Errorf("source ranges is empty")
}
var sources []netip.Prefix
for _, sourceRange := range rule.SourceRanges {
source, err := netip.ParsePrefix(sourceRange)
if err != nil {
return "", fmt.Errorf("parse source range: %w", err)
}
sources = append(sources, source)
}
var destination netip.Prefix
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil {
return "", fmt.Errorf("parse destination: %w", err)
}
}
protocol, err := convertToFirewallProtocol(rule.Protocol)
if err != nil {
return "", fmt.Errorf("invalid protocol: %w", err)
}
action, err := convertFirewallAction(rule.Action)
if err != nil {
return "", fmt.Errorf("invalid action: %w", err)
}
dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
if err != nil {
return "", fmt.Errorf("add route rule: %w", err)
}
return id.RuleID(addedRule.GetRuleID()), nil
} }
func (d *DefaultManager) protoRuleToFirewallRule( func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule, r *mgmProto.FirewallRule,
ipsetName string, ipsetName string,
) (id.RuleID, []firewall.Rule, error) { ) (string, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP) ip := net.ParseIP(r.PeerIP)
if ip == nil { if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -264,16 +179,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
} }
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { if rulesPair, ok := d.rulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil return ruleID, rulesPair, nil
} }
var rules []firewall.Rule var rules []firewall.Rule
switch r.Direction { switch r.Direction {
case mgmProto.RuleDirection_IN: case mgmProto.FirewallRule_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT: case mgmProto.FirewallRule_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -295,7 +210,7 @@ func (d *DefaultManager) addInRules(
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering( rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -306,7 +221,7 @@ func (d *DefaultManager) addInRules(
return rules, nil return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -324,7 +239,7 @@ func (d *DefaultManager) addOutRules(
comment string, comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var rules []firewall.Rule var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering( rule, err := d.firewall.AddFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -335,7 +250,7 @@ func (d *DefaultManager) addOutRules(
return rules, nil return rules, nil
} }
rule, err = d.firewall.AddPeerFiltering( rule, err = d.firewall.AddFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err) return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -344,21 +259,21 @@ func (d *DefaultManager) addOutRules(
return append(rules, rule...), nil return append(rules, rule...), nil
} }
// getPeerRuleID() returns unique ID for the rule based on its parameters. // getRuleID() returns unique ID for the rule based on its parameters.
func (d *DefaultManager) getPeerRuleID( func (d *DefaultManager) getRuleID(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
direction int, direction int,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
comment string, comment string,
) id.RuleID { ) string {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil { if port != nil {
idStr += port.String() idStr += port.String()
} }
return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
} }
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
@@ -368,7 +283,7 @@ func (d *DefaultManager) getPeerRuleID(
// but other has port definitions or has drop policy. // but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules( func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap, networkMap *mgmProto.NetworkMap,
) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { ) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
totalIPs := 0 totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps { for range p.AllowedIps {
@@ -376,14 +291,14 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
type protoMatch map[mgmProto.RuleProtocol]map[string]int type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
in := protoMatch{} in := protoMatch{}
out := protoMatch{} out := protoMatch{}
// trace which type of protocols was squashed // trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{} squashedRules := []*mgmProto.FirewallRule{}
squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not. // this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list. // We summ amount of Peers IP for given protocol we found in original rules list.
@@ -393,7 +308,7 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
if drop { if drop {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = map[string]int{}
return return
@@ -421,7 +336,7 @@ func (d *DefaultManager) squashAcceptRules(
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
// calculate squash for different directions // calculate squash for different directions
if r.Direction == mgmProto.RuleDirection_IN { if r.Direction == mgmProto.FirewallRule_IN {
addRuleToCalculationMap(i, r, in) addRuleToCalculationMap(i, r, in)
} else { } else {
addRuleToCalculationMap(i, r, out) addRuleToCalculationMap(i, r, out)
@@ -430,14 +345,14 @@ func (d *DefaultManager) squashAcceptRules(
// order of squashing by protocol is important // order of squashing by protocol is important
// only for their first element ALL, it must be done first // only for their first element ALL, it must be done first
protocolOrders := []mgmProto.RuleProtocol{ protocolOrders := []mgmProto.FirewallRuleProtocol{
mgmProto.RuleProtocol_ALL, mgmProto.FirewallRule_ALL,
mgmProto.RuleProtocol_ICMP, mgmProto.FirewallRule_ICMP,
mgmProto.RuleProtocol_TCP, mgmProto.FirewallRule_TCP,
mgmProto.RuleProtocol_UDP, mgmProto.FirewallRule_UDP,
} }
squash := func(matches protoMatch, direction mgmProto.RuleDirection) { squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if : // don't squash if :
@@ -450,12 +365,12 @@ func (d *DefaultManager) squashAcceptRules(
squashedRules = append(squashedRules, &mgmProto.FirewallRule{ squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0", PeerIP: "0.0.0.0",
Direction: direction, Direction: direction,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: protocol, Protocol: protocol,
}) })
squashedProtocols[protocol] = struct{}{} squashedProtocols[protocol] = struct{}{}
if protocol == mgmProto.RuleProtocol_ALL { if protocol == mgmProto.FirewallRule_ALL {
// if we have ALL traffic type squashed rule // if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing // it allows all other type of traffic, so we can stop processing
break break
@@ -463,11 +378,11 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
squash(in, mgmProto.RuleDirection_IN) squash(in, mgmProto.FirewallRule_IN)
squash(out, mgmProto.RuleDirection_OUT) squash(out, mgmProto.FirewallRule_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules // if all protocol was squashed everything is allow and we can ignore all other rules
if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
return squashedRules, squashedProtocols return squashedRules, squashedProtocols
} }
@@ -497,26 +412,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
} }
func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
log.Debugf("rollback ACL to previous state") log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs { for _, rules := range newRulePairs {
for _, rule := range rules { for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil { if err := d.firewall.DeleteRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
} }
} }
} }
} }
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
switch protocol { switch protocol {
case mgmProto.RuleProtocol_TCP: case mgmProto.FirewallRule_TCP:
return firewall.ProtocolTCP, nil return firewall.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP: case mgmProto.FirewallRule_UDP:
return firewall.ProtocolUDP, nil return firewall.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP: case mgmProto.FirewallRule_ICMP:
return firewall.ProtocolICMP, nil return firewall.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL: case mgmProto.FirewallRule_ALL:
return firewall.ProtocolALL, nil return firewall.ProtocolALL, nil
default: default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
@@ -527,41 +442,13 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
} }
func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
switch action { switch action {
case mgmProto.RuleAction_ACCEPT: case mgmProto.FirewallRule_ACCEPT:
return firewall.ActionAccept, nil return firewall.ActionAccept, nil
case mgmProto.RuleAction_DROP: case mgmProto.FirewallRule_DROP:
return firewall.ActionDrop, nil return firewall.ActionDrop, nil
default: default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
} }
} }
func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
if portInfo == nil {
return nil
}
if portInfo.GetPort() != 0 {
return &firewall.Port{
Values: []int{int(portInfo.GetPort())},
}
}
if portInfo.GetRange() != nil {
return &firewall.Port{
IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
}
}
return nil
}
func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
}
return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
}

View File

@@ -9,8 +9,8 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
Port: "80", Port: "80",
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_DROP, Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.RuleProtocol_UDP, Protocol: mgmProto.FirewallRule_UDP,
Port: "53", Port: "53",
}, },
}, },
@@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.rulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
return return
} }
}) })
t.Run("add extra rules", func(t *testing.T) { t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{} existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs { for id := range acl.rulesPairs {
existedPairs[id.GetRuleID()] = struct{}{} existedPairs[id] = struct{}{}
} }
// remove first rule // remove first rule
@@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules, networkMap.FirewallRules,
&mgmProto.FirewallRule{ &mgmProto.FirewallRule{
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_DROP, Action: mgmProto.FirewallRule_DROP,
Protocol: mgmProto.RuleProtocol_ICMP, Protocol: mgmProto.FirewallRule_ICMP,
}, },
) )
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules // we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 { if len(acl.rulesPairs) != 2 {
t.Errorf("firewall rules not applied") t.Errorf("firewall rules not applied")
return return
} }
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.rulesPairs {
if _, ok := existedPairs[id.GetRuleID()]; ok { if _, ok := existedPairs[id]; ok {
previousCount++ previousCount++
} }
} }
@@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
return return
} }
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 { if len(acl.rulesPairs) != 2 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
return return
} }
}) })
@@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
}, },
} }
@@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0": case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return return
case r.Direction != mgmProto.RuleDirection_IN: case r.Direction != mgmProto.FirewallRule_IN:
t.Errorf("direction should be IN, got: %v", r.Direction) t.Errorf("direction should be IN, got: %v", r.Direction)
return return
case r.Protocol != mgmProto.RuleProtocol_ALL: case r.Protocol != mgmProto.FirewallRule_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol) t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return return
case r.Action != mgmProto.RuleAction_ACCEPT: case r.Action != mgmProto.FirewallRule_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action) t.Errorf("action should be ACCEPT, got: %v", r.Action)
return return
} }
@@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0": case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return return
case r.Direction != mgmProto.RuleDirection_OUT: case r.Direction != mgmProto.FirewallRule_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction) t.Errorf("direction should be OUT, got: %v", r.Direction)
return return
case r.Protocol != mgmProto.RuleProtocol_ALL: case r.Protocol != mgmProto.FirewallRule_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol) t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return return
case r.Action != mgmProto.RuleAction_ACCEPT: case r.Action != mgmProto.FirewallRule_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action) t.Errorf("action should be ACCEPT, got: %v", r.Action)
return return
} }
@@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
}, },
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL, Protocol: mgmProto.FirewallRule_ALL,
}, },
{ {
PeerIP: "10.93.0.4", PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP, Protocol: mgmProto.FirewallRule_UDP,
}, },
}, },
} }
@@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
{ {
PeerIP: "10.93.0.1", PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
}, },
{ {
PeerIP: "10.93.0.2", PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN, Direction: mgmProto.FirewallRule_IN,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP, Protocol: mgmProto.FirewallRule_TCP,
}, },
{ {
PeerIP: "10.93.0.3", PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_OUT, Direction: mgmProto.FirewallRule_OUT,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.FirewallRule_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP, Protocol: mgmProto.FirewallRule_UDP,
}, },
}, },
} }
@@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 { if len(acl.rulesPairs) != 4 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
return return
} }
} }

View File

@@ -8,8 +8,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
iface "github.com/netbirdio/netbird/client/iface" iface "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
// MockIFaceMapper is a mock of IFaceMapper interface. // MockIFaceMapper is a mock of IFaceMapper interface.
@@ -78,7 +77,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
} }
// SetFilter mocks base method. // SetFilter mocks base method.
func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error { func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetFilter", arg0) ret := m.ctrl.Call(m, "SetFilter", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@@ -3,7 +3,6 @@ package auth
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@@ -181,7 +180,7 @@ func (d *DeviceAuthorizationFlow) WaitToken(ctx context.Context, info AuthFlowIn
continue continue
} }
return TokenInfo{}, errors.New(tokenResponse.ErrorDescription) return TokenInfo{}, fmt.Errorf(tokenResponse.ErrorDescription)
} }
tokenInfo := TokenInfo{ tokenInfo := TokenInfo{

View File

@@ -69,11 +69,6 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config) pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
if err != nil { if err != nil {
// fallback to device code flow // fallback to device code flow
@@ -86,7 +81,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) { func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
} }

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"crypto/subtle" "crypto/subtle"
"crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
@@ -144,18 +143,6 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) { func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
cert := p.providerConfig.ClientCertPair
if cert != nil {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*cert},
},
}
sslClient := &http.Client{Transport: tr}
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
req = req.WithContext(ctx)
}
token, err := p.handleRequest(req) token, err := p.handleRequest(req)
if err != nil { if err != nil {
renderPKCEFlowTmpl(w, err) renderPKCEFlowTmpl(w, err)

View File

@@ -2,23 +2,17 @@ package internal
import ( import (
"context" "context"
"crypto/tls"
"fmt" "fmt"
"net/url" "net/url"
"os" "os"
"reflect"
"runtime"
"strings"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -54,12 +48,8 @@ type ConfigInput struct {
RosenpassPermissive *bool RosenpassPermissive *bool
InterfaceName *string InterfaceName *string
WireguardPort *int WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool DisableAutoConnect *bool
ExtraIFaceBlackList []string ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
} }
// Config Configuration type // Config Configuration type
@@ -71,7 +61,6 @@ type Config struct {
AdminURL *url.URL AdminURL *url.URL
WgIface string WgIface string
WgPort int WgPort int
NetworkMonitor *bool
IFaceBlackList []string IFaceBlackList []string
DisableIPv6Discovery bool DisableIPv6Discovery bool
RosenpassEnabled bool RosenpassEnabled bool
@@ -102,38 +91,15 @@ type Config struct {
// DisableAutoConnect determines whether the client should not start with the service // DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility // it's set to false by default due to backwards compatibility
DisableAutoConnect bool DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication
ClientCertPath string
//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"`
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) { func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) { if configFileIsExists(configPath) {
err := util.EnforcePermission(configPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
config := &Config{} config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil { if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err return nil, err
} }
// initialize through apply() without changes
if changed, err := config.apply(ConfigInput{}); err != nil {
return nil, err
} else if changed {
if err = WriteOutConfig(configPath, config); err != nil {
return nil, err
}
}
return config, nil return config, nil
} }
@@ -164,17 +130,13 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = WriteOutConfig(input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
if isPreSharedKeyHidden(input.PreSharedKey) { if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil input.PreSharedKey = nil
} }
err := util.EnforcePermission(input.ConfigPath)
if err != nil {
log.Errorf("failed to enforce permission on config dir: %v", err)
}
return update(input) return update(input)
} }
@@ -190,15 +152,79 @@ func WriteOutConfig(path string, config *Config) error {
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(input ConfigInput) (*Config, error) { func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ wgKey := generateKey()
// defaults to false only for new (post 0.26) configurations pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
ServerSSHAllowed: util.False(), if err != nil {
}
if _, err := config.apply(input); err != nil {
return nil, err return nil, err
} }
config := &Config{
SSHKey: string(pem),
PrivateKey: wgKey,
IFaceBlackList: []string{},
DisableIPv6Discovery: false,
NATExternalIPs: input.NATExternalIPs,
CustomDNSAddress: string(input.CustomDNSAddress),
ServerSSHAllowed: util.False(),
DisableAutoConnect: false,
}
defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = defaultManagementURL
if input.ManagementURL != "" {
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return nil, err
}
config.ManagementURL = URL
}
config.WgPort = iface.DefaultWgPort
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
}
config.WgIface = iface.WgInterfaceDefault
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
}
if input.PreSharedKey != nil {
config.PreSharedKey = *input.PreSharedKey
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
}
defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return nil, err
}
config.AdminURL = defaultAdminURL
if input.AdminURL != "" {
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
}
config.AdminURL = newURL
}
// nolint:gocritic
config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...)
return config, nil return config, nil
} }
@@ -209,12 +235,104 @@ func update(input ConfigInput) (*Config, error) {
return nil, err return nil, err
} }
updated, err := config.apply(input) refresh := false
if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL {
log.Infof("new Management URL provided, updated to %s (old value %s)",
input.ManagementURL, config.ManagementURL)
newURL, err := parseURL("Management URL", input.ManagementURL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
config.ManagementURL = newURL
refresh = true
}
if updated { if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) {
log.Infof("new Admin Panel URL provided, updated to %s (old value %s)",
input.AdminURL, config.AdminURL)
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return nil, err
}
config.AdminURL = newURL
refresh = true
}
if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
refresh = true
}
if config.SSHKey == "" {
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}
config.SSHKey = string(pem)
refresh = true
}
if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
refresh = true
}
if input.WireguardPort != nil {
config.WgPort = *input.WireguardPort
refresh = true
}
if input.InterfaceName != nil {
config.WgIface = *input.InterfaceName
refresh = true
}
if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) {
config.NATExternalIPs = input.NATExternalIPs
refresh = true
}
if input.CustomDNSAddress != nil {
config.CustomDNSAddress = string(input.CustomDNSAddress)
refresh = true
}
if input.RosenpassEnabled != nil {
config.RosenpassEnabled = *input.RosenpassEnabled
refresh = true
}
if input.RosenpassPermissive != nil {
config.RosenpassPermissive = *input.RosenpassPermissive
refresh = true
}
if input.DisableAutoConnect != nil {
config.DisableAutoConnect = *input.DisableAutoConnect
refresh = true
}
if input.ServerSSHAllowed != nil {
config.ServerSSHAllowed = input.ServerSSHAllowed
refresh = true
}
if config.ServerSSHAllowed == nil {
config.ServerSSHAllowed = util.True()
refresh = true
}
if len(input.ExtraIFaceBlackList) > 0 {
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
refresh = true
}
}
if refresh {
// since we have new management URL, we need to update config file
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
@@ -223,210 +341,6 @@ func update(input ConfigInput) (*Config, error) {
return config, nil return config, nil
} }
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}
if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
log.Infof("new Management URL provided, updated to %#v (old value %#v)",
input.ManagementURL, config.ManagementURL.String())
URL, err := parseURL("Management URL", input.ManagementURL)
if err != nil {
return false, err
}
config.ManagementURL = URL
updated = true
} else if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
if err != nil {
return false, err
}
}
if config.AdminURL == nil {
log.Infof("using default Admin URL %s", DefaultManagementURL)
config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
if err != nil {
return false, err
}
}
if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
input.AdminURL, config.AdminURL.String())
newURL, err := parseURL("Admin Panel URL", input.AdminURL)
if err != nil {
return updated, err
}
config.AdminURL = newURL
updated = true
}
if config.PrivateKey == "" {
log.Infof("generated new Wireguard key")
config.PrivateKey = generateKey()
updated = true
}
if config.SSHKey == "" {
log.Infof("generated new SSH key")
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return false, err
}
config.SSHKey = string(pem)
updated = true
}
if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
log.Infof("updating Wireguard port %d (old value %d)",
*input.WireguardPort, config.WgPort)
config.WgPort = *input.WireguardPort
updated = true
} else if config.WgPort == 0 {
config.WgPort = iface.DefaultWgPort
log.Infof("using default Wireguard port %d", config.WgPort)
updated = true
}
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
log.Infof("updating Wireguard interface %#v (old value %#v)",
*input.InterfaceName, config.WgIface)
config.WgIface = *input.InterfaceName
updated = true
} else if config.WgIface == "" {
config.WgIface = iface.WgInterfaceDefault
log.Infof("using default Wireguard interface %s", config.WgIface)
updated = true
}
if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
strings.Join(input.NATExternalIPs, " "),
strings.Join(config.NATExternalIPs, " "))
config.NATExternalIPs = input.NATExternalIPs
updated = true
}
if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
log.Infof("new pre-shared key provided, replacing old key")
config.PreSharedKey = *input.PreSharedKey
updated = true
}
if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
config.RosenpassEnabled = *input.RosenpassEnabled
updated = true
}
if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
config.RosenpassPermissive = *input.RosenpassPermissive
updated = true
}
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
config.NetworkMonitor = input.NetworkMonitor
updated = true
}
if config.NetworkMonitor == nil {
// enable network monitoring by default on windows and darwin clients
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
enabled := true
config.NetworkMonitor = &enabled
updated = true
}
}
if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
log.Infof("updating custom DNS address %#v (old value %#v)",
string(input.CustomDNSAddress), config.CustomDNSAddress)
config.CustomDNSAddress = string(input.CustomDNSAddress)
updated = true
}
if len(config.IFaceBlackList) == 0 {
log.Infof("filling in interface blacklist with defaults: [ %s ]",
strings.Join(defaultInterfaceBlacklist, " "))
config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...)
updated = true
}
if len(input.ExtraIFaceBlackList) > 0 {
for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
log.Infof("adding new entry to interface blacklist: %s", iFace)
config.IFaceBlackList = append(config.IFaceBlackList, iFace)
updated = true
}
}
if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
if *input.DisableAutoConnect {
log.Infof("turning off automatic connection on startup")
} else {
log.Infof("enabling automatic connection on startup")
}
config.DisableAutoConnect = *input.DisableAutoConnect
updated = true
}
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if *input.ServerSSHAllowed {
log.Infof("enabling SSH server")
} else {
log.Infof("disabling SSH server")
}
config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true
} else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
config.DNSRouteInterval = *input.DNSRouteInterval
updated = true
} else if config.DNSRouteInterval == 0 {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true
}
if input.ClientCertPath != "" {
config.ClientCertPath = input.ClientCertPath
updated = true
}
if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
if err != nil {
log.Error("Failed to load mTLS cert/key pair: ", err)
} else {
config.ClientCertKeyPair = &cert
log.Info("Loaded client mTLS cert/key pair")
}
}
return updated, nil
}
// parseURL parses and validates a service URL // parseURL parses and validates a service URL
func parseURL(serviceName, serviceURL string) (*url.URL, error) { func parseURL(serviceName, serviceURL string) (*url.URL, error) {
parsedMgmtURL, err := url.ParseRequestURI(serviceURL) parsedMgmtURL, err := url.ParseRequestURI(serviceURL)

View File

@@ -4,11 +4,9 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"runtime" "runtime"
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@@ -17,61 +15,44 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client" signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
type ConnectClient struct { // RunClient with main logic.
ctx context.Context func RunClient(ctx context.Context, config *Config, statusRecorder *peer.Status) error {
config *Config return runClient(ctx, config, statusRecorder, MobileDependency{}, nil, nil, nil, nil)
statusRecorder *peer.Status
engine *Engine
engineMutex sync.Mutex
} }
func NewConnectClient( // RunClientWithProbes runs the client's main logic with probes attached
func RunClientWithProbes(
ctx context.Context, ctx context.Context,
config *Config, config *Config,
statusRecorder *peer.Status, statusRecorder *peer.Status,
mgmProbe *Probe,
) *ConnectClient { signalProbe *Probe,
return &ConnectClient{ relayProbe *Probe,
ctx: ctx, wgProbe *Probe,
config: config,
statusRecorder: statusRecorder,
engineMutex: sync.Mutex{},
}
}
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
probes *ProbeHolder,
runningChan chan error,
) error { ) error {
return c.run(MobileDependency{}, probes, runningChan) return runClient(ctx, config, statusRecorder, MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
} }
// RunOnAndroid with main logic on mobile system // RunClientMobile with main logic on mobile system
func (c *ConnectClient) RunOnAndroid( func RunClientMobile(
tunAdapter device.TunAdapter, ctx context.Context,
config *Config,
statusRecorder *peer.Status,
tunAdapter iface.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover, iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string, dnsAddresses []string,
@@ -85,29 +66,34 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
} }
return c.run(mobileDependency, nil, nil) return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
} }
func (c *ConnectClient) RunOniOS( func RunClientiOS(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
fileDescriptor int32, fileDescriptor int32,
networkChangeListener listener.NetworkChangeListener, networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager, dnsManager dns.IosDnsManager,
) error { ) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
FileDescriptor: fileDescriptor, FileDescriptor: fileDescriptor,
NetworkChangeListener: networkChangeListener, NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager, DnsManager: dnsManager,
} }
return c.run(mobileDependency, nil, nil) return runClient(ctx, config, statusRecorder, mobileDependency, nil, nil, nil, nil)
} }
func (c *ConnectClient) run( func runClient(
ctx context.Context,
config *Config,
statusRecorder *peer.Status,
mobileDependency MobileDependency, mobileDependency MobileDependency,
probes *ProbeHolder, mgmProbe *Probe,
runningChan chan error, signalProbe *Probe,
relayProbe *Probe,
wgProbe *Probe,
) error { ) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
@@ -119,7 +105,7 @@ func (c *ConnectClient) run(
// Check if client was not shut down in a clean way and restore DNS config if required. // Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config. // Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil { if err := dns.CheckUncleanShutdown(config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err) log.Errorf("checking unclean shutdown error: %s", err)
} }
@@ -133,7 +119,7 @@ func (c *ConnectClient) run(
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
} }
state := CtxGetState(c.ctx) state := CtxGetState(ctx)
defer func() { defer func() {
s, err := state.Status() s, err := state.Status()
if err != nil || s != StatusNeedsLogin { if err != nil || s != StatusNeedsLogin {
@@ -142,50 +128,52 @@ func (c *ConnectClient) run(
}() }()
wrapErr := state.Wrap wrapErr := state.Wrap
myPrivateKey, err := wgtypes.ParseKey(c.config.PrivateKey) myPrivateKey, err := wgtypes.ParseKey(config.PrivateKey)
if err != nil { if err != nil {
log.Errorf("failed parsing Wireguard key %s: [%s]", c.config.PrivateKey, err.Error()) log.Errorf("failed parsing Wireguard key %s: [%s]", config.PrivateKey, err.Error())
return wrapErr(err) return wrapErr(err)
} }
var mgmTlsEnabled bool var mgmTlsEnabled bool
if c.config.ManagementURL.Scheme == "https" { if config.ManagementURL.Scheme == "https" {
mgmTlsEnabled = true mgmTlsEnabled = true
} }
publicSSHKey, err := ssh.GeneratePublicKey([]byte(c.config.SSHKey)) publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil { if err != nil {
return err return err
} }
defer c.statusRecorder.ClientStop() defer statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { select {
case <-ctx.Done():
return nil return nil
default:
} }
state.Set(StatusConnecting) state.Set(StatusConnecting)
engineCtx, cancel := context.WithCancel(c.ctx) engineCtx, cancel := context.WithCancel(ctx)
defer func() { defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err) statusRecorder.MarkManagementDisconnected(state.err)
c.statusRecorder.CleanLocalPeerState() statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
log.Debugf("connecting to the Management service %s", c.config.ManagementURL.Host) log.Debugf("connecting to the Management service %s", config.ManagementURL.Host)
mgmClient, err := mgm.NewClient(engineCtx, c.config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) mgmClient, err := mgm.NewClient(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled)
if err != nil { if err != nil {
return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err)) return wrapErr(gstatus.Errorf(codes.FailedPrecondition, "failed connecting to Management Service : %s", err))
} }
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder) mgmNotifier := statusRecorderToMgmConnStateNotifier(statusRecorder)
mgmClient.SetConnStateListener(mgmNotifier) mgmClient.SetConnStateListener(mgmNotifier)
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) log.Debugf("connected to the Management service %s", config.ManagementURL.Host)
defer func() { defer func() {
if err = mgmClient.Close(); err != nil { err = mgmClient.Close()
if err != nil {
log.Warnf("failed to close the Management service client %v", err) log.Warnf("failed to close the Management service client %v", err)
} }
}() }()
@@ -196,31 +184,31 @@ func (c *ConnectClient) run(
log.Debug(err) log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
_ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error return backoff.Permanent(wrapErr(err)) // unrecoverable error
} }
return wrapErr(err) return wrapErr(err)
} }
c.statusRecorder.MarkManagementConnected() statusRecorder.MarkManagementConnected()
localPeerState := peer.LocalPeerState{ localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(), IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(), PubKey: myPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded(), KernelInterface: iface.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(), FQDN: loginResp.GetPeerConfig().GetFqdn(),
} }
c.statusRecorder.UpdateLocalPeerState(localPeerState)
statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s", signalURL := fmt.Sprintf("%s://%s",
strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()), strings.ToLower(loginResp.GetWiretrusteeConfig().GetSignal().GetProtocol().String()),
loginResp.GetWiretrusteeConfig().GetSignal().GetUri(), loginResp.GetWiretrusteeConfig().GetSignal().GetUri(),
) )
c.statusRecorder.UpdateSignalAddress(signalURL) statusRecorder.UpdateSignalAddress(signalURL)
c.statusRecorder.MarkSignalDisconnected(nil) statusRecorder.MarkSignalDisconnected(nil)
defer func() { defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err) statusRecorder.MarkSignalDisconnected(state.err)
}() }()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@@ -236,71 +224,40 @@ func (c *ConnectClient) run(
} }
}() }()
signalNotifier := statusRecorderToSignalConnStateNotifier(c.statusRecorder) signalNotifier := statusRecorderToSignalConnStateNotifier(statusRecorder)
signalClient.SetConnStateListener(signalNotifier) signalClient.SetConnStateListener(signalNotifier)
c.statusRecorder.MarkSignalConnected() statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
if len(relayURLs) > 0 {
if token != nil {
if err := relayManager.UpdateToken(token); err != nil {
log.Errorf("failed to update token: %s", err)
return wrapErr(err)
}
}
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil {
log.Error(err)
return wrapErr(err)
}
c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) engineConfig, err := createEngineConfig(myPrivateKey, config, peerConfig)
if err != nil { if err != nil {
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
checks := loginResp.GetChecks() engine := NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe)
err = engine.Start()
c.engineMutex.Lock() if err != nil {
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engineMutex.Unlock()
if err := c.engine.Start(); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err) log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err) return wrapErr(err)
} }
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Print("Netbird engine started, my IP is: ", peerConfig.Address)
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil && runningChanOpen {
runningChan <- nil
close(runningChan)
runningChanOpen = false
}
<-engineCtx.Done() <-engineCtx.Done()
c.engineMutex.Lock() statusRecorder.ClientTeardown()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset() backOff.Reset()
err = engine.Stop()
if err != nil {
log.Errorf("failed stopping engine %v", err)
return wrapErr(err)
}
log.Info("stopped NetBird client") log.Info("stopped NetBird client")
if _, err := state.Status(); errors.Is(err, ErrResetConnection) { if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
@@ -310,72 +267,20 @@ func (c *ConnectClient) run(
return nil return nil
} }
c.statusRecorder.ClientStart() statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff) err = backoff.Retry(operation, backOff)
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
_ = c.Stop()
} }
return err return err
} }
return nil return nil
} }
func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
if relayCfg == nil {
return nil, nil
}
token := &hmac.Token{
Payload: relayCfg.GetTokenPayload(),
Signature: relayCfg.GetTokenSignature(),
}
return relayCfg.GetUrls(), token
}
func (c *ConnectClient) Engine() *Engine {
if c == nil {
return nil
}
var e *Engine
c.engineMutex.Lock()
e = c.engine
c.engineMutex.Unlock()
return e
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
return c.engine.Stop()
}
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// createEngineConfig converts configuration received from Management Service to EngineConfig // createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
if config.NetworkMonitor != nil {
nm = *config.NetworkMonitor
}
engineConf := &EngineConfig{ engineConf := &EngineConfig{
WgIfaceName: config.WgIface, WgIfaceName: config.WgIface,
WgAddr: peerConfig.Address, WgAddr: peerConfig.Address,
@@ -383,14 +288,12 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DisableIPv6Discovery: config.DisableIPv6Discovery, DisableIPv6Discovery: config.DisableIPv6Discovery,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: config.WgPort, WgPort: config.WgPort,
NetworkMonitor: nm,
SSHKey: []byte(config.SSHKey), SSHKey: []byte(config.SSHKey),
NATExternalIPs: config.NATExternalIPs, NATExternalIPs: config.NATExternalIPs,
CustomDNSAddress: config.CustomDNSAddress, CustomDNSAddress: config.CustomDNSAddress,
RosenpassEnabled: config.RosenpassEnabled, RosenpassEnabled: config.RosenpassEnabled,
RosenpassPermissive: config.RosenpassPermissive, RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -401,15 +304,6 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
engineConf.PreSharedKey = &preSharedKey engineConf.PreSharedKey = &preSharedKey
} }
port, err := freePort(config.WgPort)
if err != nil {
return nil, err
}
if port != config.WgPort {
log.Infof("using %d as wireguard port: %d is in use", port, config.WgPort)
}
engineConf.WgPort = port
return engineConf, nil return engineConf, nil
} }
@@ -459,44 +353,3 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
notifier, _ := sri.(signal.ConnStateNotifier) notifier, _ := sri.(signal.ConnStateNotifier)
return notifier return notifier
} }
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
if initPort == 0 {
initPort = iface.DefaultWgPort
}
addr.Port = initPort
conn, err := net.ListenUDP("udp", &addr)
if err == nil {
closeConnWithLog(conn)
return initPort, nil
}
// if the port is already in use, ask the system for a free port
addr.Port = 0
conn, err = net.ListenUDP("udp", &addr)
if err != nil {
return 0, fmt.Errorf("unable to get a free port: %v", err)
}
udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
return 0, errors.New("wrong address type when getting a free port")
}
closeConnWithLog(conn)
return udpAddr.Port, nil
}
func closeConnWithLog(conn *net.UDPConn) {
startClosing := time.Now()
err := conn.Close()
if err != nil {
log.Warnf("closing probe port %d failed: %v. NetBird will still attempt to use this port for connection.", conn.LocalAddr().(*net.UDPAddr).Port, err)
}
if time.Since(startClosing) > time.Second {
log.Warnf("closing the testing port %d took %s. Usually it is safe to ignore, but continuous warnings may indicate a problem.", conn.LocalAddr().(*net.UDPAddr).Port, time.Since(startClosing))
}
}

View File

@@ -1,61 +0,0 @@
package internal
import (
"net"
"testing"
)
func Test_freePort(t *testing.T) {
tests := []struct {
name string
port int
want int
shouldMatch bool
}{
{
name: "not provided, fallback to default",
port: 0,
want: 51820,
shouldMatch: true,
},
{
name: "provided and available",
port: 51821,
want: 51821,
shouldMatch: true,
},
{
name: "provided and not available",
port: 51830,
want: 51830,
shouldMatch: false,
},
}
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
if err != nil {
t.Errorf("freePort error = %v", err)
}
defer func(c1 *net.UDPConn) {
_ = c1.Close()
}(c1)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
if err != nil {
t.Errorf("got an error while getting free port: %v", err)
}
if tt.shouldMatch && got != tt.want {
t.Errorf("got a different port %v, want %v", got, tt.want)
}
if !tt.shouldMatch && got == tt.want {
t.Errorf("got the same port %v, want a different port", tt.want)
}
})
}
}

View File

@@ -1,6 +0,0 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
)

Some files were not shown because too many files have changed in this diff Show More