diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md
index 789c61974..87f757f42 100644
--- a/.github/ISSUE_TEMPLATE/bug-issue-report.md
+++ b/.github/ISSUE_TEMPLATE/bug-issue-report.md
@@ -35,7 +35,7 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
If applicable, add the `netbird status -dA' command output.
-**Do you face any client issues on desktop?**
+**Do you face any (non-mobile) client issues?**
Please provide the file created by `netbird debug for 1m -AS`.
We advise reviewing the anonymized files for any remaining PII.
diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml
index 2b4c43cb4..2aaef7564 100644
--- a/.github/workflows/golang-test-darwin.yml
+++ b/.github/workflows/golang-test-darwin.yml
@@ -18,14 +18,14 @@ jobs:
runs-on: macos-latest
steps:
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-go-${{ hashFiles('**/go.sum') }}
diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml
index 4f13ee30e..a2d743715 100644
--- a/.github/workflows/golang-test-freebsd.yml
+++ b/.github/workflows/golang-test-freebsd.yml
@@ -38,7 +38,7 @@ jobs:
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
- time go test -timeout 1m -failfast ./iface/...
+ time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml
index 120b213e9..524f35f6f 100644
--- a/.github/workflows/golang-test-linux.yml
+++ b/.github/workflows/golang-test-linux.yml
@@ -19,13 +19,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -33,7 +33,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -49,18 +49,18 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
- run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
+ run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./...
test_client_on_docker:
runs-on: ubuntu-20.04
steps:
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -68,7 +68,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
@@ -80,7 +80,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
- run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
+ run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -124,4 +124,4 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
- run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
\ No newline at end of file
+ run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml
index 2d63acbcd..d378bec3f 100644
--- a/.github/workflows/golang-test-windows.yml
+++ b/.github/workflows/golang-test-windows.yml
@@ -17,13 +17,13 @@ jobs:
runs-on: windows-latest
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
id: go
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Download wintun
uses: carlosperate/download-file-action@v2
diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml
index 78b9f504f..2d743f790 100644
--- a/.github/workflows/golangci-lint.yml
+++ b/.github/workflows/golangci-lint.yml
@@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
- ignore_words_list: erro,clienta,hastable,
+ ignore_words_list: erro,clienta,hastable,iif
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -32,15 +32,15 @@ jobs:
timeout-minutes: 15
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
@@ -49,4 +49,4 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
version: latest
- args: --timeout=12m
\ No newline at end of file
+ args: --timeout=12m
diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml
index dfb8a279b..22d002a48 100644
--- a/.github/workflows/install-script-test.yml
+++ b/.github/workflows/install-script-test.yml
@@ -13,6 +13,7 @@ concurrency:
jobs:
test-install-script:
strategy:
+ fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
@@ -21,7 +22,7 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: run install script
env:
diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml
index e5a5ff485..dcf461a34 100644
--- a/.github/workflows/mobile-build-validation.yml
+++ b/.github/workflows/mobile-build-validation.yml
@@ -15,23 +15,23 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
- uses: actions/setup-java@v3
+ uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -50,11 +50,11 @@ jobs:
runs-on: macos-latest
steps:
- name: Checkout repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
@@ -62,4 +62,4 @@ jobs:
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
- CGO_ENABLED: 0
\ No newline at end of file
+ CGO_ENABLED: 0
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 30f24e92e..7af6d3e4d 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -3,15 +3,14 @@ name: Release
on:
push:
tags:
- - 'v*'
+ - "v*"
branches:
- main
pull_request:
-
env:
- SIGN_PIPE_VER: "v0.0.12"
- GORELEASER_VER: "v1.14.1"
+ SIGN_PIPE_VER: "v0.0.14"
+ GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -34,20 +33,17 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
- name: Checkout
- uses: actions/checkout@v3
+ - name: Checkout
+ uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- -
- name: Set up Go
- uses: actions/setup-go@v4
+ - name: Set up Go
+ uses: actions/setup-go@v5
with:
- go-version: "1.21"
+ go-version: "1.23"
cache: false
- -
- name: Cache Go modules
- uses: actions/cache@v3
+ - name: Cache Go modules
+ uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -55,24 +51,19 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-releaser-
- -
- name: Install modules
+ - name: Install modules
run: go mod tidy
- -
- name: check git status
+ - name: check git status
run: git --no-pager diff --exit-code
- -
- name: Set up QEMU
+ - name: Set up QEMU
uses: docker/setup-qemu-action@v2
- -
- name: Set up Docker Buildx
+ - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- -
- name: Login to Docker hub
+ - name: Login to Docker hub
if: github.event_name != 'pull_request'
uses: docker/login-action@v1
with:
- username: netbirdio
+ username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -85,36 +76,32 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
- args: release --rm-dist ${{ env.flags }}
+ args: release --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- -
- name: upload non tags for debug purposes
- uses: actions/upload-artifact@v3
+ - name: upload non tags for debug purposes
+ uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 3
- -
- name: upload linux packages
- uses: actions/upload-artifact@v3
+ - name: upload linux packages
+ uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 3
- -
- name: upload windows packages
- uses: actions/upload-artifact@v3
+ - name: upload windows packages
+ uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 3
- -
- name: upload macos packages
- uses: actions/upload-artifact@v3
+ - name: upload macos packages
+ uses: actions/upload-artifact@v4
with:
name: macos-packages
path: dist/netbird_darwin**
@@ -133,19 +120,19 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- name: Set up Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21"
+ go-version: "1.23"
cache: false
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
- path: |
+ path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
@@ -169,14 +156,14 @@ jobs:
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
- args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }}
+ args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- name: upload non tags for debug purposes
- uses: actions/upload-artifact@v3
+ uses: actions/upload-artifact@v4
with:
name: release-ui
path: dist/
@@ -187,20 +174,17 @@ jobs:
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- -
- name: Checkout
- uses: actions/checkout@v3
+ - name: Checkout
+ uses: actions/checkout@v4
with:
fetch-depth: 0 # It is required for GoReleaser to work properly
- -
- name: Set up Go
- uses: actions/setup-go@v4
+ - name: Set up Go
+ uses: actions/setup-go@v5
with:
- go-version: "1.21"
+ go-version: "1.23"
cache: false
- -
- name: Cache Go modules
- uses: actions/cache@v3
+ - name: Cache Go modules
+ uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
@@ -208,52 +192,34 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin-
- -
- name: Install modules
+ - name: Install modules
run: go mod tidy
- -
- name: check git status
+ - name: check git status
run: git --no-pager diff --exit-code
- -
- name: Run GoReleaser
+ - name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@v4
with:
version: ${{ env.GORELEASER_VER }}
- args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }}
+ args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- -
- name: upload non tags for debug purposes
- uses: actions/upload-artifact@v3
+ - name: upload non tags for debug purposes
+ uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
retention-days: 3
- trigger_windows_signer:
+ trigger_signer:
runs-on: ubuntu-latest
- needs: [release,release_ui]
+ needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- - name: Trigger Windows binaries sign pipeline
+ - name: Trigger binaries sign pipelines
uses: benc-uk/workflow-dispatch@v1
with:
- workflow: Sign windows bin and installer
- repo: netbirdio/sign-pipelines
- ref: ${{ env.SIGN_PIPE_VER }}
- token: ${{ secrets.SIGN_GITHUB_TOKEN }}
- inputs: '{ "tag": "${{ github.ref }}" }'
-
- trigger_darwin_signer:
- runs-on: ubuntu-latest
- needs: [release,release_ui_darwin]
- if: startsWith(github.ref, 'refs/tags/')
- steps:
- - name: Trigger Darwin App binaries sign pipeline
- uses: benc-uk/workflow-dispatch@v1
- with:
- workflow: Sign darwin ui app with dispatch
+ workflow: Sign bin and installer
repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml
index 52b8ee3e2..da3ec746a 100644
--- a/.github/workflows/test-infrastructure-files.yml
+++ b/.github/workflows/test-infrastructure-files.yml
@@ -18,7 +18,31 @@ concurrency:
jobs:
test-docker-compose:
runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ store: [ 'sqlite', 'postgres' ]
+ services:
+ postgres:
+ image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
+ env:
+ POSTGRES_USER: netbird
+ POSTGRES_PASSWORD: postgres
+ POSTGRES_DB: netbird
+ options: >-
+ --health-cmd pg_isready
+ --health-interval 10s
+ --health-timeout 5s
+ ports:
+ - 5432:5432
steps:
+ - name: Set Database Connection String
+ run: |
+ if [ "${{ matrix.store }}" == "postgres" ]; then
+ echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
+ else
+ echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
+ fi
+
- name: Install jq
run: sudo apt-get install -y jq
@@ -26,12 +50,12 @@ jobs:
run: sudo apt-get install -y curl
- name: Install Go
- uses: actions/setup-go@v4
+ uses: actions/setup-go@v5
with:
- go-version: "1.21.x"
+ go-version: "1.23.x"
- name: Cache Go modules
- uses: actions/cache@v3
+ uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
@@ -39,7 +63,7 @@ jobs:
${{ runner.os }}-go-
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
@@ -58,7 +82,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
- CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
+ CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
+ NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
@@ -85,7 +110,8 @@ jobs:
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
- CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite"
+ CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
+ NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -123,6 +149,14 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
+ grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
+ # check relay values
+ grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
+ grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
+ grep '33445:33445' docker-compose.yml
+ grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
+ grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
+ grep -A 7 Relay management.json | egrep '"Secret": ".+"'
- name: Install modules
run: go mod tidy
@@ -148,6 +182,15 @@ jobs:
run: |
docker build -t netbirdio/signal:latest .
+ - name: Build relay binary
+ working-directory: relay
+ run: CGO_ENABLED=0 go build -o netbird-relay main.go
+
+ - name: Build relay docker image
+ working-directory: relay
+ run: |
+ docker build -t netbirdio/relay:latest .
+
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
@@ -159,15 +202,15 @@ jobs:
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
- test $count -eq 4
+          test $count -eq 5 || { docker compose logs; exit 1; }
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
- docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
- docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
+ docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
+ docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
test-getting-started-script:
runs-on: ubuntu-latest
@@ -176,7 +219,7 @@ jobs:
run: sudo apt-get install -y jq
- name: Checkout code
- uses: actions/checkout@v3
+ uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
@@ -191,7 +234,7 @@ jobs:
run: test -f management.json
- name: test turnserver.conf file gen postgres
- run: |
+ run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
@@ -202,6 +245,9 @@ jobs:
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
+ - name: test relay.env file gen postgres
+ run: test -f relay.env
+
- name: test zdb.env file gen postgres
run: test -f zdb.env
@@ -226,7 +272,7 @@ jobs:
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
- run: |
+ run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
@@ -237,20 +283,5 @@ jobs:
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- test-download-geolite2-script:
- runs-on: ubuntu-latest
- steps:
- - name: Install jq
- run: sudo apt-get update && sudo apt-get install -y unzip sqlite3
-
- - name: Checkout code
- uses: actions/checkout@v3
-
- - name: test script
- run: bash -x infrastructure_files/download-geolite2.sh
-
- - name: test mmdb file exists
- run: test -f GeoLite2-City.mmdb
-
- - name: test geonames file exists
- run: test -f geonames.db
+ - name: test relay.env file gen CockroachDB
+ run: test -f relay.env
diff --git a/.gitignore b/.gitignore
index cdce46975..d0b4f82dd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,4 +29,3 @@ infrastructure_files/setup.env
infrastructure_files/setup-*.env
.vscode
.DS_Store
-GeoLite2-City*
\ No newline at end of file
diff --git a/.goreleaser.yaml b/.goreleaser.yaml
index 7a219110a..cf2ce4f4f 100644
--- a/.goreleaser.yaml
+++ b/.goreleaser.yaml
@@ -1,3 +1,5 @@
+version: 2
+
project_name: netbird
builds:
- id: netbird
@@ -22,7 +24,7 @@ builds:
goarch: 386
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@@ -42,19 +44,19 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
- id: netbird-mgmt
dir: management
env:
- - CGO_ENABLED=1
- - >-
- {{- if eq .Runtime.Goos "linux" }}
- {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
- {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
- {{- end }}
+ - CGO_ENABLED=1
+ - >-
+ {{- if eq .Runtime.Goos "linux" }}
+ {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
+ {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
+ {{- end }}
binary: netbird-mgmt
goos:
- linux
@@ -64,7 +66,7 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-signal
dir: signal
@@ -78,7 +80,21 @@ builds:
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
+
+ - id: netbird-relay
+ dir: relay
+ env: [CGO_ENABLED=0]
+ binary: netbird-relay
+ goos:
+ - linux
+ goarch:
+ - amd64
+ - arm64
+ - arm
+ ldflags:
+ - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
+ mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- builds:
@@ -86,7 +102,6 @@ archives:
- netbird-static
nfpms:
-
- maintainer: Netbird
description: Netbird client.
homepage: https://netbird.io/
@@ -161,6 +176,52 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
+ - image_templates:
+ - netbirdio/relay:{{ .Version }}-amd64
+ ids:
+ - netbird-relay
+ goarch: amd64
+ use: buildx
+ dockerfile: relay/Dockerfile
+ build_flag_templates:
+ - "--platform=linux/amd64"
+ - "--label=org.opencontainers.image.created={{.Date}}"
+ - "--label=org.opencontainers.image.title={{.ProjectName}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=maintainer=dev@netbird.io"
+ - image_templates:
+ - netbirdio/relay:{{ .Version }}-arm64v8
+ ids:
+ - netbird-relay
+ goarch: arm64
+ use: buildx
+ dockerfile: relay/Dockerfile
+ build_flag_templates:
+ - "--platform=linux/arm64"
+ - "--label=org.opencontainers.image.created={{.Date}}"
+ - "--label=org.opencontainers.image.title={{.ProjectName}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=maintainer=dev@netbird.io"
+ - image_templates:
+ - netbirdio/relay:{{ .Version }}-arm
+ ids:
+ - netbird-relay
+ goarch: arm
+ goarm: 6
+ use: buildx
+ dockerfile: relay/Dockerfile
+ build_flag_templates:
+ - "--platform=linux/arm"
+ - "--label=org.opencontainers.image.created={{.Date}}"
+ - "--label=org.opencontainers.image.title={{.ProjectName}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=org.opencontainers.image.revision={{.FullCommit}}"
+ - "--label=org.opencontainers.image.version={{.Version}}"
+ - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
ids:
@@ -313,6 +374,18 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
+ - name_template: netbirdio/relay:{{ .Version }}
+ image_templates:
+ - netbirdio/relay:{{ .Version }}-arm64v8
+ - netbirdio/relay:{{ .Version }}-arm
+ - netbirdio/relay:{{ .Version }}-amd64
+
+ - name_template: netbirdio/relay:latest
+ image_templates:
+ - netbirdio/relay:{{ .Version }}-arm64v8
+ - netbirdio/relay:{{ .Version }}-arm
+ - netbirdio/relay:{{ .Version }}-amd64
+
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
@@ -344,10 +417,9 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64
brews:
- -
- ids:
+ - ids:
- default
- tap:
+ repository:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
@@ -364,7 +436,7 @@ brews:
uploads:
- name: debian
ids:
- - netbird-deb
+ - netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
@@ -386,4 +458,4 @@ checksum:
release:
extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh
- - glob: ./release_files/install.sh
\ No newline at end of file
+ - glob: ./release_files/install.sh
diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml
index fd92b5328..06577f4e3 100644
--- a/.goreleaser_ui.yaml
+++ b/.goreleaser_ui.yaml
@@ -1,3 +1,5 @@
+version: 2
+
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -11,7 +13,7 @@ builds:
- amd64
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows
dir: client/ui
@@ -26,7 +28,7 @@ builds:
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
archives:
- id: linux-arch
@@ -39,7 +41,6 @@ archives:
- netbird-ui-windows
nfpms:
-
- maintainer: Netbird
description: Netbird client UI.
homepage: https://netbird.io/
@@ -77,7 +78,7 @@ nfpms:
uploads:
- name: debian
ids:
- - netbird-ui-deb
+ - netbird-ui-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml
index 2c3afa91b..bccb7f471 100644
--- a/.goreleaser_ui_darwin.yaml
+++ b/.goreleaser_ui_darwin.yaml
@@ -1,3 +1,5 @@
+version: 2
+
project_name: netbird-ui
builds:
- id: netbird-ui-darwin
@@ -17,7 +19,7 @@ builds:
- softfloat
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- mod_timestamp: '{{ .CommitTimestamp }}'
+ mod_timestamp: "{{ .CommitTimestamp }}"
tags:
- load_wgnt_from_rsrc
@@ -28,4 +30,4 @@ archives:
checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog:
- skip: true
\ No newline at end of file
+ disable: true
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 492aa5c2e..c82cfc763 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser**
```shell
-goreleaser --snapshot --rm-dist
+goreleaser build --snapshot --clean
```
**golangci-lint**
```shell
diff --git a/README.md b/README.md
index 370445412..aa3ec41e5 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@
-
+
@@ -30,7 +30,7 @@
See Documentation
- Join our Slack channel
+ Join our Slack channel
diff --git a/client/Dockerfile b/client/Dockerfile
index a3220bf33..b9f7c1355 100644
--- a/client/Dockerfile
+++ b/client/Dockerfile
@@ -1,4 +1,4 @@
-FROM alpine:3.19
+FROM alpine:3.20
RUN apk add --no-cache ca-certificates iptables ip6tables
ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
diff --git a/client/android/client.go b/client/android/client.go
index d937e132e..229bcd974 100644
--- a/client/android/client.go
+++ b/client/android/client.go
@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
@@ -15,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net"
)
@@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile
type TunAdapter interface {
- iface.TunAdapter
+ device.TunAdapter
}
// IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service
type Client struct {
cfgFile string
- tunAdapter iface.TunAdapter
+ tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover
recorder *peer.Status
ctxCancel context.CancelFunc
diff --git a/client/cmd/down.go b/client/cmd/down.go
index 4d9f1eba4..3a324cc19 100644
--- a/client/cmd/down.go
+++ b/client/cmd/down.go
@@ -42,6 +42,8 @@ var downCmd = &cobra.Command{
log.Errorf("call service down method: %v", err)
return err
}
+
+ cmd.Println("Disconnected")
return nil
},
}
diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go
index 6bb7eff4f..fa20435ea 100644
--- a/client/cmd/login_test.go
+++ b/client/cmd/login_test.go
@@ -5,8 +5,8 @@ import (
"strings"
"testing"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go
index abb7d41b2..4cbbe8783 100644
--- a/client/cmd/root_test.go
+++ b/client/cmd/root_test.go
@@ -4,6 +4,10 @@ import (
"fmt"
"io"
"testing"
+
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/client/iface"
)
func TestInitCommands(t *testing.T) {
@@ -34,3 +38,44 @@ func TestInitCommands(t *testing.T) {
})
}
}
+
+func TestSetFlagsFromEnvVars(t *testing.T) {
+ var cmd = &cobra.Command{
+ Use: "netbird",
+ Long: "test",
+ SilenceUsage: true,
+ Run: func(cmd *cobra.Command, args []string) {
+ SetFlagsFromEnvVars(cmd)
+ },
+ }
+
+ cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
+ `comma separated list of external IPs to map to the Wireguard interface`)
+ cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
+ cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.")
+ cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
+
+ t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec")
+ t.Setenv("NB_INTERFACE_NAME", "test-name")
+ t.Setenv("NB_ENABLE_ROSENPASS", "true")
+ t.Setenv("NB_WIREGUARD_PORT", "10000")
+ err := cmd.Execute()
+ if err != nil {
+ t.Fatalf("expected no error while running netbird command, got %v", err)
+ }
+ if len(natExternalIPs) != 2 {
+ t.Errorf("expected 2 external ips, got %d", len(natExternalIPs))
+ }
+ if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" {
+ t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1])
+ }
+ if interfaceName != "test-name" {
+ t.Errorf("expected test-name, got %s", interfaceName)
+ }
+ if !rosenpassEnabled {
+ t.Errorf("expected rosenpassEnabled to be true, got false")
+ }
+ if wireguardPort != 10000 {
+ t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort)
+ }
+}
diff --git a/client/cmd/service.go b/client/cmd/service.go
index 5c60744f9..855eb30fa 100644
--- a/client/cmd/service.go
+++ b/client/cmd/service.go
@@ -2,18 +2,21 @@ package cmd
import (
"context"
+
"github.com/kardianos/service"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"github.com/netbirdio/netbird/client/internal"
+ "github.com/netbirdio/netbird/client/server"
)
type program struct {
- ctx context.Context
- cancel context.CancelFunc
- serv *grpc.Server
+ ctx context.Context
+ cancel context.CancelFunc
+ serv *grpc.Server
+ serverInstance *server.Server
}
func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go
index d416afaac..86546e31c 100644
--- a/client/cmd/service_controller.go
+++ b/client/cmd/service_controller.go
@@ -61,6 +61,8 @@ func (p *program) Start(svc service.Service) error {
}
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
+ p.serverInstance = serverInstance
+
log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil {
log.Errorf("failed to serve daemon requests: %v", err)
@@ -70,6 +72,14 @@ func (p *program) Start(svc service.Service) error {
}
func (p *program) Stop(srv service.Service) error {
+ if p.serverInstance != nil {
+ in := new(proto.DownRequest)
+ _, err := p.serverInstance.Down(p.ctx, in)
+ if err != nil {
+ log.Errorf("failed to stop daemon: %v", err)
+ }
+ }
+
p.cancel()
if p.serv != nil {
diff --git a/client/cmd/status.go b/client/cmd/status.go
index d9b7a9c91..ed3daa2b5 100644
--- a/client/cmd/status.go
+++ b/client/cmd/status.go
@@ -31,9 +31,9 @@ type peerStateDetailOutput struct {
Status string `json:"status" yaml:"status"`
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
ConnType string `json:"connectionType" yaml:"connectionType"`
- Direct bool `json:"direct" yaml:"direct"`
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
+ RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
@@ -335,16 +335,18 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
func mapPeers(peers []*proto.PeerState) peersStateOutput {
var peersStateDetail []peerStateDetailOutput
- localICE := ""
- remoteICE := ""
- localICEEndpoint := ""
- remoteICEEndpoint := ""
- connType := ""
peersConnected := 0
- lastHandshake := time.Time{}
- transferReceived := int64(0)
- transferSent := int64(0)
for _, pbPeerState := range peers {
+ localICE := ""
+ remoteICE := ""
+ localICEEndpoint := ""
+ remoteICEEndpoint := ""
+ relayServerAddress := ""
+ connType := ""
+ lastHandshake := time.Time{}
+ transferReceived := int64(0)
+ transferSent := int64(0)
+
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
if skipDetailByFilters(pbPeerState, isPeerConnected) {
continue
@@ -360,6 +362,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
if pbPeerState.Relayed {
connType = "Relayed"
}
+ relayServerAddress = pbPeerState.GetRelayAddress()
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
transferReceived = pbPeerState.GetBytesRx()
transferSent = pbPeerState.GetBytesTx()
@@ -372,7 +375,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Status: pbPeerState.GetConnStatus(),
LastStatusUpdate: timeLocal,
ConnType: connType,
- Direct: pbPeerState.GetDirect(),
IceCandidateType: iceCandidateType{
Local: localICE,
Remote: remoteICE,
@@ -381,6 +383,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
Local: localICEEndpoint,
Remote: remoteICEEndpoint,
},
+ RelayAddress: relayServerAddress,
FQDN: pbPeerState.GetFqdn(),
LastWireguardHandshake: lastHandshake,
TransferReceived: transferReceived,
@@ -641,9 +644,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Status: %s\n"+
" -- detail --\n"+
" Connection type: %s\n"+
- " Direct: %t\n"+
" ICE candidate (Local/Remote): %s/%s\n"+
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
+ " Relay server address: %s\n"+
" Last connection update: %s\n"+
" Last WireGuard handshake: %s\n"+
" Transfer status (received/sent) %s/%s\n"+
@@ -655,11 +658,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
peerState.PubKey,
peerState.Status,
peerState.ConnType,
- peerState.Direct,
localICE,
remoteICE,
localICEEndpoint,
remoteICEEndpoint,
+ peerState.RelayAddress,
timeAgo(peerState.LastStatusUpdate),
timeAgo(peerState.LastWireguardHandshake),
toIEC(peerState.TransferReceived),
@@ -802,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
}
+
+ peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
+
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go
index 46620a956..ca43df8a5 100644
--- a/client/cmd/status_test.go
+++ b/client/cmd/status_test.go
@@ -37,7 +37,6 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
Relayed: false,
- Direct: true,
LocalIceCandidateType: "",
RemoteIceCandidateType: "",
LocalIceCandidateEndpoint: "",
@@ -57,7 +56,6 @@ var resp = &proto.StatusResponse{
ConnStatus: "Connected",
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
Relayed: true,
- Direct: false,
LocalIceCandidateType: "relay",
RemoteIceCandidateType: "prflx",
LocalIceCandidateEndpoint: "10.0.0.1:10001",
@@ -137,7 +135,6 @@ var overview = statusOutputOverview{
Status: "Connected",
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
ConnType: "P2P",
- Direct: true,
IceCandidateType: iceCandidateType{
Local: "",
Remote: "",
@@ -161,7 +158,6 @@ var overview = statusOutputOverview{
Status: "Connected",
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
ConnType: "Relayed",
- Direct: false,
IceCandidateType: iceCandidateType{
Local: "relay",
Remote: "prflx",
@@ -283,7 +279,6 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected",
"lastStatusUpdate": "2001-01-01T01:01:01Z",
"connectionType": "P2P",
- "direct": true,
"iceCandidateType": {
"local": "",
"remote": ""
@@ -292,6 +287,7 @@ func TestParsingToJSON(t *testing.T) {
"local": "",
"remote": ""
},
+ "relayAddress": "",
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
"transferReceived": 200,
"transferSent": 100,
@@ -308,7 +304,6 @@ func TestParsingToJSON(t *testing.T) {
"status": "Connected",
"lastStatusUpdate": "2002-02-02T02:02:02Z",
"connectionType": "Relayed",
- "direct": false,
"iceCandidateType": {
"local": "relay",
"remote": "prflx"
@@ -317,6 +312,7 @@ func TestParsingToJSON(t *testing.T) {
"local": "10.0.0.1:10001",
"remote": "10.0.10.1:10002"
},
+ "relayAddress": "",
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
"transferReceived": 2000,
"transferSent": 1000,
@@ -408,13 +404,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected
lastStatusUpdate: 2001-01-01T01:01:01Z
connectionType: P2P
- direct: true
iceCandidateType:
local: ""
remote: ""
iceCandidateEndpoint:
local: ""
remote: ""
+ relayAddress: ""
lastWireguardHandshake: 2001-01-01T01:01:02Z
transferReceived: 200
transferSent: 100
@@ -428,13 +424,13 @@ func TestParsingToYAML(t *testing.T) {
status: Connected
lastStatusUpdate: 2002-02-02T02:02:02Z
connectionType: Relayed
- direct: false
iceCandidateType:
local: relay
remote: prflx
iceCandidateEndpoint:
local: 10.0.0.1:10001
remote: 10.0.10.1:10002
+ relayAddress: ""
lastWireguardHandshake: 2002-02-02T02:02:03Z
transferReceived: 2000
transferSent: 1000
@@ -505,9 +501,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected
-- detail --
Connection type: P2P
- Direct: true
ICE candidate (Local/Remote): -/-
ICE candidate endpoints (Local/Remote): -/-
+ Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 200 B/100 B
@@ -521,9 +517,9 @@ func TestParsingToDetail(t *testing.T) {
Status: Connected
-- detail --
Connection type: Relayed
- Direct: false
ICE candidate (Local/Remote): relay/prflx
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
+ Relay server address:
Last connection update: %s
Last WireGuard handshake: %s
Transfer status (received/sent) 2.0 KiB/1000 B
diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go
index 984aa6df7..f0dc8bf21 100644
--- a/client/cmd/testutil_test.go
+++ b/client/cmd/testutil_test.go
@@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err)
}
s := grpc.NewServer()
- srv, err := sig.NewServer(otel.Meter(""))
+ srv, err := sig.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv)
@@ -98,8 +98,9 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
if err != nil {
t.Fatal(err)
}
- turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
t.Fatal(err)
}
diff --git a/client/cmd/up.go b/client/cmd/up.go
index 2ed6e41d2..05ecce9e0 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -15,11 +15,11 @@ import (
gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
)
@@ -168,7 +168,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
- connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String()))
+ r := peer.NewRecorder(config.ManagementURL.String())
+ r.GetFullStatus()
+
+ connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
}
diff --git a/client/errors/errors.go b/client/errors/errors.go
index cef999ac8..8faadbda5 100644
--- a/client/errors/errors.go
+++ b/client/errors/errors.go
@@ -8,8 +8,8 @@ import (
)
func formatError(es []error) string {
- if len(es) == 0 {
- return fmt.Sprintf("0 error occurred:\n\t* %s", es[0])
+ if len(es) == 1 {
+ return fmt.Sprintf("1 error occurred:\n\t* %s", es[0])
}
points := make([]string, len(es))
diff --git a/client/firewall/iface.go b/client/firewall/iface.go
index 882daef75..f349f9210 100644
--- a/client/firewall/iface.go
+++ b/client/firewall/iface.go
@@ -1,11 +1,13 @@
package firewall
-import "github.com/netbirdio/netbird/iface"
+import (
+ "github.com/netbirdio/netbird/client/iface/device"
+)
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
- Address() iface.WGAddress
+ Address() device.WGAddress
IsUserspaceBind() bool
- SetFilter(iface.PacketFilter) error
+ SetFilter(device.PacketFilter) error
}
diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go
index b77cc8f43..c6a96a876 100644
--- a/client/firewall/iptables/acl_linux.go
+++ b/client/firewall/iptables/acl_linux.go
@@ -19,24 +19,22 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
-
- postRoutingMark = "0x000007e4"
)
type aclManager struct {
- iptablesClient *iptables.IPTables
- wgIface iFaceMapper
- routeingFwChainName string
+ iptablesClient *iptables.IPTables
+ wgIface iFaceMapper
+ routingFwChainName string
entries map[string][][]string
ipsetStore *ipsetStore
}
-func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) {
+func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
m := &aclManager{
- iptablesClient: iptablesClient,
- wgIface: wgIface,
- routeingFwChainName: routeingFwChainName,
+ iptablesClient: iptablesClient,
+ wgIface: wgIface,
+ routingFwChainName: routingFwChainName,
entries: make(map[string][][]string),
ipsetStore: newIpsetStore(),
@@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route
return m, nil
}
-func (m *aclManager) AddFiltering(
+func (m *aclManager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering(
return nil, fmt.Errorf("rule already exists")
}
- if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil {
+ if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
return nil, err
}
@@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering(
chain: chain,
}
- if !shouldAddToPrerouting(protocol, dPort, direction) {
- return []firewall.Rule{rule}, nil
- }
-
- rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip)
- if err != nil {
- return []firewall.Rule{rule}, err
- }
- return []firewall.Rule{rule, rulePrerouting}, nil
+ return []firewall.Rule{rule}, nil
}
-// DeleteRule from the firewall by rule definition
-func (m *aclManager) DeleteRule(rule firewall.Rule) error {
+// DeletePeerRule from the firewall by rule definition
+func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
}
- if r.chain == "PREROUTING" {
- goto DELETERULE
- }
-
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok {
@@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error {
}
}
-DELETERULE:
- var table string
- if r.chain == "PREROUTING" {
- table = "mangle"
- } else {
- table = "filter"
- }
- err := m.iptablesClient.Delete(table, r.chain, r.specs...)
+ err := m.iptablesClient.Delete(tableName, r.chain, r.specs...)
if err != nil {
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
}
@@ -203,44 +182,6 @@ func (m *aclManager) Reset() error {
return m.cleanChains()
}
-func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
- var src []string
- if ipsetName != "" {
- src = []string{"-m", "set", "--set", ipsetName, "src"}
- } else {
- src = []string{"-s", ip.String()}
- }
- specs := []string{
- "-d", m.wgIface.Address().IP.String(),
- "-p", protocol,
- "--dport", port,
- "-j", "MARK", "--set-mark", postRoutingMark,
- }
-
- specs = append(src, specs...)
-
- ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...)
- if err != nil {
- return nil, fmt.Errorf("failed to check rule: %w", err)
- }
- if ok {
- return nil, fmt.Errorf("rule already exists")
- }
-
- if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
- return nil, err
- }
-
- rule := &Rule{
- ruleID: uuid.New().String(),
- specs: specs,
- ipsetName: ipsetName,
- ip: ip.String(),
- chain: "PREROUTING",
- }
- return rule, nil
-}
-
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
@@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error {
}
}
- ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
- if err != nil {
- log.Debugf("failed to list chains: %s", err)
- return err
- }
- if ok {
- for _, rule := range m.entries["PREROUTING"] {
- err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...)
- if err != nil {
- log.Errorf("failed to delete rule: %v, %s", rule, err)
- }
- }
- err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
- if err != nil {
- log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
- return err
- }
- }
-
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := ipset.Flush(ipsetName); err != nil {
log.Errorf("flush ipset %q during reset: %v", ipsetName, err)
@@ -338,17 +260,9 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries {
for _, rule := range rules {
- if chainName == "FORWARD" {
- // position 2 because we add it after router's, jump rule
- if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil {
- log.Debugf("failed to create input chain jump rule: %s", err)
- return err
- }
- } else {
- if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
- log.Debugf("failed to create input chain jump rule: %s", err)
- return err
- }
+ if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
+ log.Debugf("failed to create input chain jump rule: %s", err)
+ return err
}
}
}
@@ -356,40 +270,29 @@ func (m *aclManager) createDefaultChains() error {
return nil
}
+// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
+// We want to make sure our traffic is not dropped by existing rules.
+
+// The existing FORWARD rules/policies decide outbound traffic towards our interface.
+// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
+
+// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {
- m.appendToEntries("INPUT",
- []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
- m.appendToEntries("INPUT",
- []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-
- m.appendToEntries("INPUT",
- []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
+ established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
-
- m.appendToEntries("OUTPUT",
- []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-
- m.appendToEntries("OUTPUT",
- []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
-
- m.appendToEntries("OUTPUT",
- []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
+ m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
+ m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
+ m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
+ m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
+ m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
- m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
- m.appendToEntries("FORWARD",
- []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
- m.appendToEntries("FORWARD",
- []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
- m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
- m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
-
- m.appendToEntries("PREROUTING",
- []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark})
+ m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
+ m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
@@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName
}
}
-
-func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
- if proto == "all" {
- return false
- }
-
- if direction != firewall.RuleDirectionIN {
- return false
- }
-
- if dPort == nil {
- return false
- }
- return true
-}
diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go
index 2d231ec45..6fefd58e6 100644
--- a/client/firewall/iptables/manager_linux.go
+++ b/client/firewall/iptables/manager_linux.go
@@ -4,13 +4,14 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"sync"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
// Manager of iptables firewall
@@ -21,7 +22,7 @@ type Manager struct {
ipv4Client *iptables.IPTables
aclMgr *aclManager
- router *routerManager
+ router *router
}
// iFaceMapper defines subset methods of interface required for manager
@@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
- m.router, err = newRouterManager(context, iptablesClient)
+ m.router, err = newRouter(context, iptablesClient, wgIface)
if err != nil {
log.Debugf("failed to initialize route related chains: %s", err)
return nil, err
}
- m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
+ m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err)
return nil, err
@@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering adds a rule to the firewall
//
// Comment will be ignored because some system this feature is not supported
-func (m *Manager) AddFiltering(
+func (m *Manager) AddPeerFiltering(
ip net.IP,
protocol firewall.Protocol,
sPort *firewall.Port,
@@ -73,33 +74,62 @@ func (m *Manager) AddFiltering(
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
+ return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
}
-// DeleteRule from the firewall by rule definition
-func (m *Manager) DeleteRule(rule firewall.Rule) error {
+func (m *Manager) AddRouteFiltering(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto firewall.Protocol,
+ sPort *firewall.Port,
+ dPort *firewall.Port,
+ action firewall.Action,
+) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.aclMgr.DeleteRule(rule)
+ if !destination.Addr().Is4() {
+ return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
+ }
+
+ return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
+}
+
+// DeletePeerRule from the firewall by rule definition
+func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.aclMgr.DeletePeerRule(rule)
+}
+
+func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
-func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.InsertRoutingRules(pair)
+ return m.router.AddNatRule(pair)
}
-func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.RemoveRoutingRules(pair)
+ return m.router.RemoveNatRule(pair)
+}
+
+func (m *Manager) SetLegacyManagement(isLegacy bool) error {
+ return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state
@@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error {
return nil
}
- _, err := m.AddFiltering(
+ _, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
@@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error {
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
}
- _, err = m.AddFiltering(
+ _, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
@@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
+
+func getConntrackEstablished() []string {
+ return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
+}
diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go
index ceb116c62..498d8f58b 100644
--- a/client/firewall/iptables/manager_linux_test.go
+++ b/client/firewall/iptables/manager_linux_test.go
@@ -11,9 +11,24 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
+var ifaceMock = &iFaceMock{
+ NameFunc: func() string {
+ return "lo"
+ },
+ AddressFunc: func() iface.WGAddress {
+ return iface.WGAddress{
+ IP: net.ParseIP("10.20.0.1"),
+ Network: &net.IPNet{
+ IP: net.ParseIP("10.20.0.0"),
+ Mask: net.IPv4Mask(255, 255, 255, 0),
+ },
+ }
+ },
+}
+
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
@@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
- mock := &iFaceMock{
- NameFunc: func() string {
- return "lo"
- },
- AddressFunc: func() iface.WGAddress {
- return iface.WGAddress{
- IP: net.ParseIP("10.20.0.1"),
- Network: &net.IPNet{
- IP: net.ParseIP("10.20.0.0"),
- Mask: net.IPv4Mask(255, 255, 255, 0),
- },
- }
- },
- }
-
// just check on the local interface
- manager, err := Create(context.Background(), mock)
+ manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
- rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{
Values: []int{8043: 8046},
}
- rule2, err = manager.AddFiltering(
+ rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
@@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
- _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
+ _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset()
@@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
- rule1, err = manager.AddFiltering(
+ rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
@@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{
Values: []int{443},
}
- rule2, err = manager.AddFiltering(
+ rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
@@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
- err := manager.DeleteRule(r)
+ err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go
index e8f09a106..737b20785 100644
--- a/client/firewall/iptables/router_linux.go
+++ b/client/firewall/iptables/router_linux.go
@@ -5,368 +5,478 @@ package iptables
import (
"context"
"fmt"
+ "net/netip"
+ "strconv"
"strings"
"github.com/coreos/go-iptables/iptables"
+ "github.com/hashicorp/go-multierror"
+ "github.com/nadoo/ipset"
log "github.com/sirupsen/logrus"
+ nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/internal/acl/id"
+ "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
const (
- Ipv4Forwarding = "netbird-rt-forwarding"
- ipv4Nat = "netbird-rt-nat"
+ ipv4Nat = "netbird-rt-nat"
)
// constants needed to manage and create iptable rules
const (
tableFilter = "filter"
tableNat = "nat"
- chainFORWARD = "FORWARD"
chainPOSTROUTING = "POSTROUTING"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
+
+ matchSet = "--match-set"
)
-type routerManager struct {
- ctx context.Context
- stop context.CancelFunc
- iptablesClient *iptables.IPTables
- rules map[string][]string
+type routeFilteringRuleParams struct {
+ Sources []netip.Prefix
+ Destination netip.Prefix
+ Proto firewall.Protocol
+ SPort *firewall.Port
+ DPort *firewall.Port
+ Direction firewall.RuleDirection
+ Action firewall.Action
+ SetName string
}
-func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) {
+type router struct {
+ ctx context.Context
+ stop context.CancelFunc
+ iptablesClient *iptables.IPTables
+ rules map[string][]string
+ ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}]
+ wgIface iFaceMapper
+ legacyManagement bool
+}
+
+func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
- m := &routerManager{
+ r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient,
rules: make(map[string][]string),
+ wgIface: wgIface,
}
- err := m.cleanUpDefaultForwardRules()
+ r.ipsetCounter = refcounter.New(
+ r.createIpSet,
+ func(name string, _ struct{}) error {
+ return r.deleteIpSet(name)
+ },
+ )
+
+ if err := ipset.Init(); err != nil {
+ return nil, fmt.Errorf("init ipset: %w", err)
+ }
+
+ err := r.cleanUpDefaultForwardRules()
if err != nil {
- log.Errorf("failed to cleanup routing rules: %s", err)
+ log.Errorf("cleanup routing rules: %s", err)
return nil, err
}
- err = m.createContainers()
+ err = r.createContainers()
if err != nil {
- log.Errorf("failed to create containers for route: %s", err)
+ log.Errorf("create containers for route: %s", err)
}
- return m, err
+ return r, err
}
-// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain
-func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error {
- err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
- if err != nil {
- return err
+func (r *router) AddRouteFiltering(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto firewall.Protocol,
+ sPort *firewall.Port,
+ dPort *firewall.Port,
+ action firewall.Action,
+) (firewall.Rule, error) {
+ ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
+ if _, ok := r.rules[string(ruleKey)]; ok {
+ return ruleKey, nil
}
- err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair))
- if err != nil {
- return err
+ var setName string
+ if len(sources) > 1 {
+ setName = firewall.GenerateSetName(sources)
+ if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
+ return nil, fmt.Errorf("create or get ipset: %w", err)
+ }
+ }
+
+ params := routeFilteringRuleParams{
+ Sources: sources,
+ Destination: destination,
+ Proto: proto,
+ SPort: sPort,
+ DPort: dPort,
+ Action: action,
+ SetName: setName,
+ }
+
+ rule := genRouteFilteringRuleSpec(params)
+ if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
+ return nil, fmt.Errorf("add route rule: %v", err)
+ }
+
+ r.rules[string(ruleKey)] = rule
+
+ return ruleKey, nil
+}
+
+func (r *router) DeleteRouteRule(rule firewall.Rule) error {
+ ruleKey := rule.GetRuleID()
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ setName := r.findSetNameInRule(rule)
+
+ if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
+ return fmt.Errorf("delete route rule: %v", err)
+ }
+ delete(r.rules, ruleKey)
+
+ if setName != "" {
+ if _, err := r.ipsetCounter.Decrement(setName); err != nil {
+ return fmt.Errorf("failed to remove ipset: %w", err)
+ }
+ }
+ } else {
+ log.Debugf("route rule %s not found", ruleKey)
+ }
+
+ return nil
+}
+
+func (r *router) findSetNameInRule(rule []string) string {
+ for i, arg := range rule {
+ if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
+ return rule[i+3]
+ }
+ }
+ return ""
+}
+
+func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) {
+ if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
+ return struct{}{}, fmt.Errorf("create set %s: %w", setName, err)
+ }
+
+ for _, prefix := range sources {
+ if err := ipset.AddPrefix(setName, prefix); err != nil {
+ return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err)
+ }
+ }
+
+ return struct{}{}, nil
+}
+
+func (r *router) deleteIpSet(setName string) error {
+ if err := ipset.Destroy(setName); err != nil {
+ return fmt.Errorf("destroy set %s: %w", setName, err)
+ }
+ return nil
+}
+
+// AddNatRule inserts an iptables rule pair into the nat chain
+func (r *router) AddNatRule(pair firewall.RouterPair) error {
+ if r.legacyManagement {
+ log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
+ if err := r.addLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("add legacy routing rule: %w", err)
+ }
}
if !pair.Masquerade {
return nil
}
- err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair)
- if err != nil {
- return err
+ if err := r.addNatRule(pair); err != nil {
+ return fmt.Errorf("add nat rule: %w", err)
}
- err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair))
- if err != nil {
- return err
+ if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("add inverse nat rule: %w", err)
}
return nil
}
-// insertRoutingRule inserts an iptables rule
-func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
- var err error
+// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
+func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
+ if err := r.removeNatRule(pair); err != nil {
+ return fmt.Errorf("remove nat rule: %w", err)
+ }
- ruleKey := firewall.GenKey(keyFormat, pair.ID)
- rule := genRuleSpec(jump, pair.Source, pair.Destination)
- existingRule, found := i.rules[ruleKey]
- if found {
- err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
- if err != nil {
- return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
+ if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("remove inverse nat rule: %w", err)
+ }
+
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("remove legacy routing rule: %w", err)
+ }
+
+ return nil
+}
+
+// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
+func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return err
+ }
+
+ rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
+ if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
+ return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
+ }
+
+ r.rules[ruleKey] = rule
+
+ return nil
+}
+
+func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
+ return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
- delete(i.rules, ruleKey)
- }
-
- err = i.iptablesClient.Insert(table, chain, 1, rule...)
- if err != nil {
- return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
- }
-
- i.rules[ruleKey] = rule
-
- return nil
-}
-
-// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
-func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
- err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
- if err != nil {
- return err
- }
-
- err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
- if err != nil {
- return err
- }
-
- if !pair.Masquerade {
- return nil
- }
-
- err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
- if err != nil {
- return err
- }
-
- err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
- if err != nil {
- return err
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("legacy forwarding rule %s not found", ruleKey)
}
return nil
}
-func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error {
- var err error
+// GetLegacyManagement returns the current legacy management mode
+func (r *router) GetLegacyManagement() bool {
+ return r.legacyManagement
+}
- ruleKey := firewall.GenKey(keyFormat, pair.ID)
- existingRule, found := i.rules[ruleKey]
- if found {
- err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
- if err != nil {
- return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
+// SetLegacyManagement sets the route manager to use legacy management mode
+func (r *router) SetLegacyManagement(isLegacy bool) {
+ r.legacyManagement = isLegacy
+}
+
+// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
+func (r *router) RemoveAllLegacyRouteRules() error {
+ var merr *multierror.Error
+ for k, rule := range r.rules {
+ if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
+ continue
+ }
+ if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
}
}
- delete(i.rules, ruleKey)
-
- return nil
+ return nberrors.FormatErrorOrNil(merr)
}
-func (i *routerManager) RouteingFwChainName() string {
- return chainRTFWD
-}
-
-func (i *routerManager) Reset() error {
- err := i.cleanUpDefaultForwardRules()
- if err != nil {
- return err
+func (r *router) Reset() error {
+ var merr *multierror.Error
+ if err := r.cleanUpDefaultForwardRules(); err != nil {
+ merr = multierror.Append(merr, err)
}
- i.rules = make(map[string][]string)
- return nil
+ r.rules = make(map[string][]string)
+
+ if err := r.ipsetCounter.Flush(); err != nil {
+ merr = multierror.Append(merr, err)
+ }
+
+ return nberrors.FormatErrorOrNil(merr)
}
-func (i *routerManager) cleanUpDefaultForwardRules() error {
- err := i.cleanJumpRules()
+func (r *router) cleanUpDefaultForwardRules() error {
+ err := r.cleanJumpRules()
if err != nil {
return err
}
log.Debug("flushing routing related tables")
- ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD)
- if err != nil {
- log.Errorf("failed check chain %s,error: %v", chainRTFWD, err)
- return err
- } else if ok {
- err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD)
+ for _, chain := range []string{chainRTFWD, chainRTNAT} {
+ table := tableFilter
+ if chain == chainRTNAT {
+ table = tableNat
+ }
+
+ ok, err := r.iptablesClient.ChainExists(table, chain)
if err != nil {
- log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err)
+ log.Errorf("failed check chain %s, error: %v", chain, err)
return err
+ } else if ok {
+ err = r.iptablesClient.ClearAndDeleteChain(table, chain)
+ if err != nil {
+ log.Errorf("failed cleaning chain %s, error: %v", chain, err)
+ return err
+ }
}
}
- ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT)
- if err != nil {
- log.Errorf("failed check chain %s,error: %v", chainRTNAT, err)
- return err
- } else if ok {
- err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT)
- if err != nil {
- log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
- return err
- }
- }
- return nil
-}
-
-func (i *routerManager) createContainers() error {
- if i.rules[Ipv4Forwarding] != nil {
- return nil
- }
-
- errMSGFormat := "failed creating chain %s,error: %v"
- err := i.createChain(tableFilter, chainRTFWD)
- if err != nil {
- return fmt.Errorf(errMSGFormat, chainRTFWD, err)
- }
-
- err = i.createChain(tableNat, chainRTNAT)
- if err != nil {
- return fmt.Errorf(errMSGFormat, chainRTNAT, err)
- }
-
- err = i.addJumpRules()
- if err != nil {
- return fmt.Errorf("error while creating jump rules: %v", err)
- }
-
return nil
}
-// addJumpRules create jump rules to send packets to NetBird chains
-func (i *routerManager) addJumpRules() error {
- rule := []string{"-j", chainRTFWD}
- err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
+func (r *router) createContainers() error {
+ for _, chain := range []string{chainRTFWD, chainRTNAT} {
+ if err := r.createAndSetupChain(chain); err != nil {
+ return fmt.Errorf("create chain %s: %v", chain, err)
+ }
+ }
+
+ if err := r.insertEstablishedRule(chainRTFWD); err != nil {
+ return fmt.Errorf("insert established rule: %v", err)
+ }
+
+ return r.addJumpRules()
+}
+
+func (r *router) createAndSetupChain(chain string) error {
+ table := r.getTableForChain(chain)
+
+ if err := r.iptablesClient.NewChain(table, chain); err != nil {
+ return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
+ }
+
+ return nil
+}
+
+func (r *router) getTableForChain(chain string) string {
+ if chain == chainRTNAT {
+ return tableNat
+ }
+ return tableFilter
+}
+
+func (r *router) insertEstablishedRule(chain string) error {
+ establishedRule := getConntrackEstablished()
+
+ err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
+ if err != nil {
+ return fmt.Errorf("failed to insert established rule: %v", err)
+ }
+
+ ruleKey := "established-" + chain
+ r.rules[ruleKey] = establishedRule
+
+ return nil
+}
+
+func (r *router) addJumpRules() error {
+ rule := []string{"-j", chainRTNAT}
+ err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
- i.rules[Ipv4Forwarding] = rule
-
- rule = []string{"-j", chainRTNAT}
- err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
- if err != nil {
- return err
- }
- i.rules[ipv4Nat] = rule
+ r.rules[ipv4Nat] = rule
return nil
}
-// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
-func (i *routerManager) cleanJumpRules() error {
- var err error
- errMSGFormat := "failed cleaning rule from chain %s,err: %v"
- rule, found := i.rules[Ipv4Forwarding]
+func (r *router) cleanJumpRules() error {
+ rule, found := r.rules[ipv4Nat]
if found {
- err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
+ err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
- return fmt.Errorf(errMSGFormat, chainFORWARD, err)
- }
- }
- rule, found = i.rules[ipv4Nat]
- if found {
- err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
- if err != nil {
- return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
+ return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err)
}
}
- rules, err := i.iptablesClient.List("nat", "POSTROUTING")
- if err != nil {
- return fmt.Errorf("failed to list rules: %s", err)
- }
-
- for _, ruleString := range rules {
- if !strings.Contains(ruleString, "NETBIRD") {
- continue
- }
- rule := strings.Fields(ruleString)
- err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
- if err != nil {
- return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
- }
- }
-
- rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
- if err != nil {
- return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
- }
-
- for _, ruleString := range rules {
- if !strings.Contains(ruleString, "NETBIRD") {
- continue
- }
- rule := strings.Fields(ruleString)
- err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
- if err != nil {
- return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
- }
- }
return nil
}
-func (i *routerManager) createChain(table, newChain string) error {
- chains, err := i.iptablesClient.ListChains(table)
- if err != nil {
- return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
- }
+func (r *router) addNatRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
- shouldCreateChain := true
- for _, chain := range chains {
- if chain == newChain {
- shouldCreateChain = false
- }
- }
-
- if shouldCreateChain {
- err = i.iptablesClient.NewChain(table, newChain)
- if err != nil {
- return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
- }
-
- // Add the loopback return rule to the NAT chain
- loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
- err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
- if err != nil {
- return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
- }
-
- err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
- if err != nil {
- return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
- }
-
- }
- return nil
-}
-
-// addNATRule appends an iptables rule pair to the nat chain
-func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
- ruleKey := firewall.GenKey(keyFormat, pair.ID)
- rule := genRuleSpec(jump, pair.Source, pair.Destination)
- existingRule, found := i.rules[ruleKey]
- if found {
- err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
- if err != nil {
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
- delete(i.rules, ruleKey)
+ delete(r.rules, ruleKey)
}
- // inserting after loopback ignore rule
- err := i.iptablesClient.Insert(table, chain, 2, rule...)
- if err != nil {
+ rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse)
+ if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
- i.rules[ruleKey] = rule
+ r.rules[ruleKey] = rule
return nil
}
-// genRuleSpec generates rule specification
-func genRuleSpec(jump, source, destination string) []string {
- return []string{"-s", source, "-d", destination, "-j", jump}
+func (r *router) removeNatRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil {
+ return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err)
+ }
+
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("nat rule %s not found", ruleKey)
+ }
+
+ return nil
}
-func getIptablesRuleType(table string) string {
- ruleType := "forwarding"
- if table == tableNat {
- ruleType = "nat"
+func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string {
+ intdir := "-i"
+ if inverse {
+ intdir = "-o"
}
- return ruleType
+ return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump}
+}
+
+func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
+ var rule []string
+
+ if params.SetName != "" {
+ rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
+ } else if len(params.Sources) > 0 {
+ source := params.Sources[0]
+ rule = append(rule, "-s", source.String())
+ }
+
+ rule = append(rule, "-d", params.Destination.String())
+
+ if params.Proto != firewall.ProtocolALL {
+ rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
+ rule = append(rule, applyPort("--sport", params.SPort)...)
+ rule = append(rule, applyPort("--dport", params.DPort)...)
+ }
+
+ rule = append(rule, "-j", actionToStr(params.Action))
+
+ return rule
+}
+
+func applyPort(flag string, port *firewall.Port) []string {
+ if port == nil {
+ return nil
+ }
+
+ if port.IsRange && len(port.Values) == 2 {
+ return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
+ }
+
+ if len(port.Values) > 1 {
+ portList := make([]string, len(port.Values))
+ for i, p := range port.Values {
+ portList[i] = strconv.Itoa(p)
+ }
+ return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
+ }
+
+ return []string{flag, strconv.Itoa(port.Values[0])}
}
diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go
index 79b970c36..6cede09e2 100644
--- a/client/firewall/iptables/router_linux_test.go
+++ b/client/firewall/iptables/router_linux_test.go
@@ -4,11 +4,13 @@ package iptables
import (
"context"
+ "net/netip"
"os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
- manager, err := newRouterManager(context.TODO(), iptablesClient)
+ manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager")
defer func() {
@@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.Len(t, manager.rules, 2, "should have created rules map")
- exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
- require.True(t, exists, "forwarding rule should exist")
-
- exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
+ exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist")
pair := firewall.RouterPair{
ID: "abc",
- Source: "100.100.100.1/32",
- Destination: "100.100.100.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true,
}
- forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
+ forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...)
require.NoError(t, err, "inserting rule should not return error")
- nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
+ nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
@@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
}
-func TestIptablesManager_InsertRoutingRules(t *testing.T) {
+func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
@@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
- manager, err := newRouterManager(context.TODO(), iptablesClient)
+ manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
defer func() {
@@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
}()
- err = manager.InsertRoutingRules(testCase.InputPair)
+ err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted")
- forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
- forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
+ natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
- exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.True(t, exists, "forwarding rule should exist")
-
- foundRule, found := manager.rules[forwardRuleKey]
- require.True(t, found, "forwarding rule should exist in the manager map")
- require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
-
- inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
- inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
-
- exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.True(t, exists, "income forwarding rule should exist")
-
- foundRule, found = manager.rules[inForwardRuleKey]
- require.True(t, found, "income forwarding rule should exist in the manager map")
- require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
-
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
- natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
-
- exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
+ exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
@@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
require.False(t, foundNat, "nat rule should not exist in the map")
}
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
- inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
+ inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
@@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
}
}
-func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
+func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() {
t.SkipNow()
@@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
- manager, err := newRouterManager(context.TODO(), iptablesClient)
+ manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error")
defer func() {
_ = manager.Reset()
@@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "shouldn't return error")
- forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
- forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination)
-
- err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...)
- require.NoError(t, err, "inserting rule should not return error")
-
- inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
- inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
-
- err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...)
- require.NoError(t, err, "inserting rule should not return error")
-
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
- natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
+ natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
- inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
+ inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
@@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
- err = manager.RemoveRoutingRules(testCase.InputPair)
+ err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
- exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.False(t, exists, "forwarding rule should not exist")
-
- _, found := manager.rules[forwardRuleKey]
- require.False(t, found, "forwarding rule should exist in the manager map")
-
- exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
- require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
- require.False(t, exists, "income forwarding rule should not exist")
-
- _, found = manager.rules[inForwardRuleKey]
- require.False(t, found, "income forwarding rule should exist in the manager map")
-
- exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
+ exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
- _, found = manager.rules[natRuleKey]
+ _, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
@@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
-
+ })
+ }
+}
+
+func TestRouter_AddRouteFiltering(t *testing.T) {
+ if !isIptablesSupported() {
+ t.Skip("iptables not supported on this system")
+ }
+
+ iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
+ require.NoError(t, err, "Failed to create iptables client")
+
+ r, err := newRouter(context.Background(), iptablesClient, ifaceMock)
+ require.NoError(t, err, "Failed to create router manager")
+
+ defer func() {
+ err := r.Reset()
+ require.NoError(t, err, "Failed to reset router")
+ }()
+
+ tests := []struct {
+ name string
+ sources []netip.Prefix
+ destination netip.Prefix
+ proto firewall.Protocol
+ sPort *firewall.Port
+ dPort *firewall.Port
+ direction firewall.RuleDirection
+ action firewall.Action
+ expectSet bool
+ }{
+ {
+ name: "Basic TCP rule with single source",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolTCP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{80}},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with multiple sources",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("172.16.0.0/16"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolUDP,
+ sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionDrop,
+ expectSet: true,
+ },
+ {
+ name: "All protocols rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
+ destination: netip.MustParsePrefix("0.0.0.0/0"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "ICMP rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolICMP,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with multiple source ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
+ destination: netip.MustParsePrefix("192.168.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{80, 443, 8080}},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with single IP and port range",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolUDP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with source and destination ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
+ destination: netip.MustParsePrefix("172.16.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
+ dPort: &firewall.Port{Values: []int{22}},
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "Drop all incoming traffic",
+ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
+ destination: netip.MustParsePrefix("192.168.0.0/24"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
+ require.NoError(t, err, "AddRouteFiltering failed")
+
+ // Check if the rule is in the internal map
+ rule, ok := r.rules[ruleKey.GetRuleID()]
+ assert.True(t, ok, "Rule not found in internal map")
+
+ // Log the internal rule
+ t.Logf("Internal rule: %v", rule)
+
+ // Check if the rule exists in iptables
+ exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
+ assert.NoError(t, err, "Failed to check rule existence")
+ assert.True(t, exists, "Rule not found in iptables")
+
+ // Verify rule content
+ params := routeFilteringRuleParams{
+ Sources: tt.sources,
+ Destination: tt.destination,
+ Proto: tt.proto,
+ SPort: tt.sPort,
+ DPort: tt.dPort,
+ Action: tt.action,
+ SetName: "",
+ }
+
+ expectedRule := genRouteFilteringRuleSpec(params)
+
+ if tt.expectSet {
+ setName := firewall.GenerateSetName(tt.sources)
+ params.SetName = setName
+ expectedRule = genRouteFilteringRuleSpec(params)
+
+ // Check if the set was created
+ _, exists := r.ipsetCounter.Get(setName)
+ assert.True(t, exists, "IPSet not created")
+ }
+
+ assert.Equal(t, expectedRule, rule, "Rule content mismatch")
+
+ // Clean up
+ err = r.DeleteRouteRule(ruleKey)
+ require.NoError(t, err, "Failed to delete rule")
})
}
}
diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go
index 6e4edb63e..a6185d370 100644
--- a/client/firewall/manager/firewall.go
+++ b/client/firewall/manager/firewall.go
@@ -1,15 +1,21 @@
package manager
import (
+ "crypto/sha256"
+ "encoding/hex"
"fmt"
"net"
+ "net/netip"
+ "sort"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
)
const (
- NatFormat = "netbird-nat-%s"
- ForwardingFormat = "netbird-fwd-%s"
- InNatFormat = "netbird-nat-in-%s"
- InForwardingFormat = "netbird-fwd-in-%s"
+ ForwardingFormatPrefix = "netbird-fwd-"
+ ForwardingFormat = "netbird-fwd-%s-%t"
+ NatFormat = "netbird-nat-%s-%t"
)
// Rule abstraction should be implemented by each firewall manager
@@ -49,11 +55,11 @@ type Manager interface {
// AllowNetbird allows netbird interface traffic
AllowNetbird() error
- // AddFiltering rule to the firewall
+ // AddPeerFiltering adds a rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
- AddFiltering(
+ AddPeerFiltering(
ip net.IP,
proto Protocol,
sPort *Port,
@@ -64,17 +70,25 @@ type Manager interface {
comment string,
) ([]Rule, error)
- // DeleteRule from the firewall by rule definition
- DeleteRule(rule Rule) error
+ // DeletePeerRule from the firewall by rule definition
+ DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool
- // InsertRoutingRules inserts a routing firewall rule
- InsertRoutingRules(pair RouterPair) error
+ AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
- // RemoveRoutingRules removes a routing firewall rule
- RemoveRoutingRules(pair RouterPair) error
+ // DeleteRouteRule deletes a routing rule
+ DeleteRouteRule(rule Rule) error
+
+ // AddNatRule inserts a routing NAT rule
+ AddNatRule(pair RouterPair) error
+
+ // RemoveNatRule removes a routing NAT rule
+ RemoveNatRule(pair RouterPair) error
+
+ // SetLegacyManagement sets the legacy management mode
+ SetLegacyManagement(legacy bool) error
// Reset firewall to the default state
Reset() error
@@ -83,6 +97,89 @@ type Manager interface {
Flush() error
}
-func GenKey(format string, input string) string {
- return fmt.Sprintf(format, input)
+func GenKey(format string, pair RouterPair) string {
+ return fmt.Sprintf(format, pair.ID, pair.Inverse)
+}
+
+// LegacyManager defines the interface for legacy management operations
+type LegacyManager interface {
+ RemoveAllLegacyRouteRules() error
+ GetLegacyManagement() bool
+ SetLegacyManagement(bool)
+}
+
+// SetLegacyManagement sets the route manager to use legacy management
+func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
+ oldLegacy := router.GetLegacyManagement()
+
+ if oldLegacy != isLegacy {
+ router.SetLegacyManagement(isLegacy)
+ log.Debugf("Set legacy management to %v", isLegacy)
+ }
+
+ // client reconnected to a newer mgmt, we need to clean up the legacy rules
+ if !isLegacy && oldLegacy {
+ if err := router.RemoveAllLegacyRouteRules(); err != nil {
+ return fmt.Errorf("remove legacy routing rules: %v", err)
+ }
+
+ log.Debugf("Legacy routing rules removed")
+ }
+
+ return nil
+}
+
+// GenerateSetName generates a unique name for an ipset based on the given sources.
+func GenerateSetName(sources []netip.Prefix) string {
+ // sort for consistent naming
+ sortPrefixes(sources)
+
+ var sourcesStr strings.Builder
+ for _, src := range sources {
+ sourcesStr.WriteString(src.String())
+ }
+
+ hash := sha256.Sum256([]byte(sourcesStr.String()))
+ shortHash := hex.EncodeToString(hash[:])[:8]
+
+ return fmt.Sprintf("nb-%s", shortHash)
+}
+
+// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
+func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
+ if len(prefixes) == 0 {
+ return prefixes
+ }
+
+ merged := []netip.Prefix{prefixes[0]}
+ for _, prefix := range prefixes[1:] {
+ last := merged[len(merged)-1]
+ if last.Contains(prefix.Addr()) {
+ // If the current prefix is contained within the last merged prefix, skip it
+ continue
+ }
+ if prefix.Contains(last.Addr()) {
+ // If the current prefix contains the last merged prefix, replace it
+ merged[len(merged)-1] = prefix
+ } else {
+ // Otherwise, add the current prefix to the merged list
+ merged = append(merged, prefix)
+ }
+ }
+
+ return merged
+}
+
+// sortPrefixes sorts the given slice of netip.Prefix in place.
+// It sorts first by IP address, then by prefix length (most specific to least specific).
+func sortPrefixes(prefixes []netip.Prefix) {
+ sort.Slice(prefixes, func(i, j int) bool {
+ addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
+ if addrCmp != 0 {
+ return addrCmp < 0
+ }
+
+ // If IP addresses are the same, compare prefix lengths (longer prefixes first)
+ return prefixes[i].Bits() > prefixes[j].Bits()
+ })
}
diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go
new file mode 100644
index 000000000..3f47d6679
--- /dev/null
+++ b/client/firewall/manager/firewall_test.go
@@ -0,0 +1,192 @@
+package manager_test
+
+import (
+ "net/netip"
+ "reflect"
+ "regexp"
+ "testing"
+
+ "github.com/netbirdio/netbird/client/firewall/manager"
+)
+
+func TestGenerateSetName(t *testing.T) {
+ t.Run("Different orders result in same hash", func(t *testing.T) {
+ prefixes1 := []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ }
+ prefixes2 := []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ }
+
+ result1 := manager.GenerateSetName(prefixes1)
+ result2 := manager.GenerateSetName(prefixes2)
+
+ if result1 != result2 {
+ t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
+ }
+ })
+
+ t.Run("Result format is correct", func(t *testing.T) {
+ prefixes := []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ }
+
+ result := manager.GenerateSetName(prefixes)
+
+ matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
+ if err != nil {
+ t.Fatalf("Error matching regex: %v", err)
+ }
+ if !matched {
+ t.Errorf("Result format is incorrect: %s", result)
+ }
+ })
+
+ t.Run("Empty input produces consistent result", func(t *testing.T) {
+ result1 := manager.GenerateSetName([]netip.Prefix{})
+ result2 := manager.GenerateSetName([]netip.Prefix{})
+
+ if result1 != result2 {
+ t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
+ }
+ })
+
+ t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
+ prefixes1 := []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("2001:db8::/32"),
+ }
+ prefixes2 := []netip.Prefix{
+ netip.MustParsePrefix("2001:db8::/32"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ }
+
+ result1 := manager.GenerateSetName(prefixes1)
+ result2 := manager.GenerateSetName(prefixes2)
+
+ if result1 != result2 {
+ t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
+ }
+ })
+}
+
+func TestMergeIPRanges(t *testing.T) {
+ tests := []struct {
+ name string
+ input []netip.Prefix
+ expected []netip.Prefix
+ }{
+ {
+ name: "Empty input",
+ input: []netip.Prefix{},
+ expected: []netip.Prefix{},
+ },
+ {
+ name: "Single range",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ },
+ {
+ name: "Two non-overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("10.0.0.0/8"),
+ },
+ },
+ {
+ name: "One range containing another",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ {
+ name: "One range containing another (different order)",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ {
+ name: "Overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.1.128/25"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ },
+ {
+ name: "Overlapping ranges (different order)",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.128/25"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.0/24"),
+ },
+ },
+ {
+ name: "Multiple overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.2.0/24"),
+ netip.MustParsePrefix("192.168.1.128/25"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ {
+ name: "Partially overlapping ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/23"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.2.0/25"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("192.168.0.0/23"),
+ netip.MustParsePrefix("192.168.2.0/25"),
+ },
+ },
+ {
+ name: "IPv6 ranges",
+ input: []netip.Prefix{
+ netip.MustParsePrefix("2001:db8::/32"),
+ netip.MustParsePrefix("2001:db8:1::/48"),
+ netip.MustParsePrefix("2001:db8:2::/48"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("2001:db8::/32"),
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := manager.MergeIPRanges(tt.input)
+ if !reflect.DeepEqual(result, tt.expected) {
+ t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
+ }
+ })
+ }
+}
diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go
index b63a9f104..8c94b7dd4 100644
--- a/client/firewall/manager/routerpair.go
+++ b/client/firewall/manager/routerpair.go
@@ -1,18 +1,26 @@
package manager
+import (
+ "net/netip"
+
+ "github.com/netbirdio/netbird/route"
+)
+
type RouterPair struct {
- ID string
- Source string
- Destination string
+ ID route.ID
+ Source netip.Prefix
+ Destination netip.Prefix
Masquerade bool
+ Inverse bool
}
-func GetInPair(pair RouterPair) RouterPair {
+func GetInversePair(pair RouterPair) RouterPair {
return RouterPair{
ID: pair.ID,
// invert Source/Destination
Source: pair.Destination,
Destination: pair.Source,
Masquerade: pair.Masquerade,
+ Inverse: true,
}
}
diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go
index 1fa41b63a..eaf7fb6a0 100644
--- a/client/firewall/nftables/acl_linux.go
+++ b/client/firewall/nftables/acl_linux.go
@@ -16,7 +16,7 @@ import (
"golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
const (
@@ -33,9 +33,10 @@ const (
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
+const flushError = "flush: %w"
+
var (
- anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
- postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
+ anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
type AclManager struct {
@@ -48,7 +49,6 @@ type AclManager struct {
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
- chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@@ -64,7 +64,7 @@ type iFaceMapper interface {
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation)
- // and is permanent. Using same connection for booth type of operations
+ // and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting())
if err != nil {
@@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa
return m, nil
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
-func (m *AclManager) AddFiltering(
+func (m *AclManager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering(
}
newRules = append(newRules, ioRule)
- if !shouldAddToPrerouting(proto, dPort, direction) {
- return newRules, nil
- }
-
- preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
- if err != nil {
- return newRules, err
- }
- newRules = append(newRules, preroutingRule)
return newRules, nil
}
-// DeleteRule from the firewall by rule definition
-func (m *AclManager) DeleteRule(rule firewall.Rule) error {
+// DeletePeerRule from the firewall by rule definition
+func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule)
if !ok {
return fmt.Errorf("invalid rule type")
@@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error {
return nil
}
-// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for
-// input and output chains
+// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
@@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1,
DestRegister: 1,
Len: 4,
- Mask: []byte{0x00, 0x00, 0x00, 0x00},
- Xor: zeroXor,
+ Mask: []byte{0, 0, 0, 0},
+ Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
- Data: []byte{0x00, 0x00, 0x00, 0x00},
+ Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
@@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1,
DestRegister: 1,
Len: 4,
- Mask: []byte{0x00, 0x00, 0x00, 0x00},
- Xor: zeroXor,
+ Mask: []byte{0, 0, 0, 0},
+ Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
- Data: []byte{0x00, 0x00, 0x00, 0x00},
+ Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
@@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut,
})
- err := m.rConn.Flush()
- if err != nil {
- log.Debugf("failed to create default allow rules: %s", err)
- return err
+ if err := m.rConn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
}
return nil
}
@@ -290,15 +278,11 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
}
- if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
- log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
- }
-
return nil
}
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
- ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset)
+ ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
@@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil
}
- ifaceKey := expr.MetaKeyIIFNAME
- if direction == firewall.RuleDirectionOUT {
- ifaceKey = expr.MetaKeyOIFNAME
- }
- expressions := []expr.Any{
- &expr.Meta{Key: ifaceKey, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- }
+ var expressions []expr.Any
if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{
@@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1),
})
- var protoData []byte
- switch proto {
- case firewall.ProtocolTCP:
- protoData = []byte{unix.IPPROTO_TCP}
- case firewall.ProtocolUDP:
- protoData = []byte{unix.IPPROTO_UDP}
- case firewall.ProtocolICMP:
- protoData = []byte{unix.IPPROTO_ICMP}
- default:
- return nil, fmt.Errorf("unsupported protocol: %s", proto)
+ protoData, err := protoToInt(proto)
+ if err != nil {
+ return nil, fmt.Errorf("convert protocol to number: %v", err)
}
+
expressions = append(expressions, &expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
- Data: protoData,
+ Data: []byte{protoData},
})
}
@@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else {
chain = m.chainOutputRules
}
- nftRule := m.rConn.InsertRule(&nftables.Rule{
+ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
- Position: 0,
Exprs: expressions,
UserData: userData,
})
@@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil
}
-func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
- var protoData []byte
- switch proto {
- case firewall.ProtocolTCP:
- protoData = []byte{unix.IPPROTO_TCP}
- case firewall.ProtocolUDP:
- protoData = []byte{unix.IPPROTO_UDP}
- case firewall.ProtocolICMP:
- protoData = []byte{unix.IPPROTO_ICMP}
- default:
- return nil, fmt.Errorf("unsupported protocol: %s", proto)
- }
-
- ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
- if r, ok := m.rules[ruleId]; ok {
- return &Rule{
- r.nftRule,
- r.nftSet,
- r.ruleID,
- ip,
- }, nil
- }
-
- var ipExpression expr.Any
- // add individual IP for match if no ipset defined
- rawIP := ip.To4()
- if ipset == nil {
- ipExpression = &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: rawIP,
- }
- } else {
- ipExpression = &expr.Lookup{
- SourceRegister: 1,
- SetName: ipset.Name,
- SetID: ipset.ID,
- }
- }
-
- expressions := []expr.Any{
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- ipExpression,
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: m.wgIface.Address().IP.To4(),
- },
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: uint32(9),
- Len: uint32(1),
- },
- &expr.Cmp{
- Register: 1,
- Op: expr.CmpOpEq,
- Data: protoData,
- },
- }
-
- if port != nil {
- expressions = append(expressions,
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseTransportHeader,
- Offset: 2,
- Len: 2,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: encodePort(*port),
- },
- )
- }
-
- expressions = append(expressions,
- &expr.Immediate{
- Register: 1,
- Data: postroutingMark,
- },
- &expr.Meta{
- Key: expr.MetaKeyMARK,
- SourceRegister: true,
- Register: 1,
- },
- )
-
- nftRule := m.rConn.InsertRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainPrerouting,
- Position: 0,
- Exprs: expressions,
- UserData: []byte(ruleId),
- })
-
- if err := m.rConn.Flush(); err != nil {
- return nil, fmt.Errorf("flush insert rule: %v", err)
- }
-
- rule := &Rule{
- nftRule: nftRule,
- nftSet: ipset,
- ruleID: ruleId,
- ip: ip,
- }
-
- m.rules[ruleId] = rule
- if ipset != nil {
- m.ipsetStore.AddReferenceToIpset(ipset.Name)
- }
- return rule, nil
-}
-
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err)
- return err
+ return fmt.Errorf(flushError, err)
}
m.chainInputRules = chain
@@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
- //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
- m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
- m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush()
@@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
- m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
@@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
- m.addJumpRulesToRtForward() // to
- m.addMarkAccept()
- m.addJumpRuleToInputChain() // to netbird-acl-input-rules
+ m.addJumpRulesToRtForward() // to netbird-rt-fwd
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
+
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
- return err
+ return fmt.Errorf(flushError, err)
}
- // netbird-acl-output-filter
- // type filter hook output priority filter; policy accept;
- m.chainPrerouting = m.createPreroutingMangle()
- err = m.rConn.Flush()
- if err != nil {
- log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
- return err
- }
return nil
}
@@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() {
Chain: m.chainFwFilter,
Exprs: expressions,
})
-
- expressions = []expr.Any{
- &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Verdict{
- Kind: expr.VerdictJump,
- Chain: m.routeingFwChainName,
- },
- }
-
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainFwFilter,
- Exprs: expressions,
- })
-}
-
-func (m *AclManager) addMarkAccept() {
- // oifname "wt0" meta mark 0x000007e4 accept
- // iifname "wt0" meta mark 0x000007e4 accept
- ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
- for _, iface := range ifaces {
- expressions := []expr.Any{
- &expr.Meta{Key: iface, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Meta{
- Key: expr.MetaKeyMARK,
- Register: 1,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: postroutingMark,
- },
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
- }
-
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainFwFilter,
- Exprs: expressions,
- })
- }
}
func (m *AclManager) createChain(name string) *nftables.Chain {
@@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
}
chain = m.rConn.AddChain(chain)
+
+ insertReturnTrafficRule(m.rConn, m.workTable, chain)
+
return chain
}
@@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
return m.rConn.AddChain(chain)
}
-func (m *AclManager) createPreroutingMangle() *nftables.Chain {
- polAccept := nftables.ChainPolicyAccept
- chain := &nftables.Chain{
- Name: "netbird-acl-prerouting-filter",
- Table: m.workTable,
- Hooknum: nftables.ChainHookPrerouting,
- Priority: nftables.ChainPriorityMangle,
- Type: nftables.ChainTypeFilter,
- Policy: &polAccept,
- }
-
- chain = m.rConn.AddChain(chain)
-
- ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
- expressions := []expr.Any{
- &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: expr.CmpOpNeq,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: m.wgIface.Address().IP.To4(),
- },
- &expr.Immediate{
- Register: 1,
- Data: postroutingMark,
- },
- &expr.Meta{
- Key: expr.MetaKeyMARK,
- SourceRegister: true,
- Register: 1,
- },
- }
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: chain,
- Exprs: expressions,
- })
- return chain
-}
-
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
@@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}
-func (m *AclManager) addJumpRuleToInputChain() {
- expressions := []expr.Any{
- &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Verdict{
- Kind: expr.VerdictJump,
- Chain: m.chainInputRules.Name,
- },
- }
-
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: m.workTable,
- Chain: m.chainFwFilter,
- Exprs: expressions,
- })
-}
-
-func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
- ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
- var srcOp, dstOp expr.CmpOp
- if netIfName == expr.MetaKeyIIFNAME {
- srcOp = expr.CmpOpEq
- dstOp = expr.CmpOpNeq
- } else {
- srcOp = expr.CmpOpNeq
- dstOp = expr.CmpOpEq
- }
- expressions := []expr.Any{
- &expr.Meta{Key: netIfName, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname(m.wgIface.Name()),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: srcOp,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: dstOp,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
- }
- _ = m.rConn.AddRule(&nftables.Rule{
- Table: chain.Table,
- Chain: chain,
- Exprs: expressions,
- })
-}
-
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
- var srcOp, dstOp expr.CmpOp
- if iifname == expr.MetaKeyIIFNAME {
- srcOp = expr.CmpOpNeq
- dstOp = expr.CmpOpEq
- } else {
- srcOp = expr.CmpOpEq
- dstOp = expr.CmpOpNeq
- }
+ dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
@@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: srcOp,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
@@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
- ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
@@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 12,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
- &expr.Payload{
- DestRegister: 2,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: 16,
- Len: 4,
- },
- &expr.Bitwise{
- SourceRegister: 2,
- DestRegister: 2,
- Len: 4,
- Xor: []byte{0x0, 0x0, 0x0, 0x0},
- Mask: m.wgIface.Address().Network.Mask,
- },
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 2,
- Data: ip.Unmap().AsSlice(),
- },
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: to,
},
}
+
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
@@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil
}
-func generateRuleId(
+func generatePeerRuleId(
ip net.IP,
sPort *firewall.Port,
dPort *firewall.Port,
@@ -1155,33 +718,6 @@ func generateRuleId(
}
return "set:" + ipset.Name + rulesetID
}
-func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
- // case of icmp port is empty
- var p string
- if port != nil {
- p = port.String()
- }
- if ipset != nil {
- return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
- } else {
- return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
- }
-}
-
-func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
- if proto == "all" {
- return false
- }
-
- if direction != firewall.RuleDirectionIN {
- return false
- }
-
- if dPort == nil && proto != firewall.ProtocolICMP {
- return false
- }
- return true
-}
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
@@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte {
b := make([]byte, 16)
- copy(b, []byte(n+"\x00"))
+ copy(b, n+"\x00")
return b
}
+
+func protoToInt(protocol firewall.Protocol) (uint8, error) {
+ switch protocol {
+ case firewall.ProtocolTCP:
+ return unix.IPPROTO_TCP, nil
+ case firewall.ProtocolUDP:
+ return unix.IPPROTO_UDP, nil
+ case firewall.ProtocolICMP:
+ return unix.IPPROTO_ICMP, nil
+ }
+
+ return 0, fmt.Errorf("unsupported protocol: %s", protocol)
+}
diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go
index a376c98c3..d2258ae08 100644
--- a/client/firewall/nftables/manager_linux.go
+++ b/client/firewall/nftables/manager_linux.go
@@ -5,9 +5,11 @@ import (
"context"
"fmt"
"net"
+ "net/netip"
"sync"
"github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
@@ -15,8 +17,11 @@ import (
)
const (
- // tableName is the name of the table that is used for filtering by the Netbird client
- tableName = "netbird"
+ // tableNameNetbird is the name of the table that is used for filtering by the Netbird client
+ tableNameNetbird = "netbird"
+
+ tableNameFilter = "filter"
+ chainNameInput = "INPUT"
)
// Manager of iptables firewall
@@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return nil, err
}
- m.router, err = newRouter(context, workTable)
+ m.router, err = newRouter(context, workTable, wgIface)
if err != nil {
return nil, err
}
- m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
+ m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil {
return nil, err
}
@@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
return m, nil
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
-func (m *Manager) AddFiltering(
+func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -76,33 +81,52 @@ func (m *Manager) AddFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
- return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
+ return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
}
-// DeleteRule from the firewall by rule definition
-func (m *Manager) DeleteRule(rule firewall.Rule) error {
+func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.aclManager.DeleteRule(rule)
+ if !destination.Addr().Is4() {
+ return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
+ }
+
+ return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
+}
+
+// DeletePeerRule from the firewall by rule definition
+func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.aclManager.DeletePeerRule(rule)
+}
+
+// DeleteRouteRule deletes a routing rule
+func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
+ m.mutex.Lock()
+ defer m.mutex.Unlock()
+
+ return m.router.DeleteRouteRule(rule)
}
func (m *Manager) IsServerRouteSupported() bool {
return true
}
-func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.AddRoutingRules(pair)
+ return m.router.AddNatRule(pair)
}
-func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock()
defer m.mutex.Unlock()
- return m.router.RemoveRoutingRules(pair)
+ return m.router.RemoveNatRule(pair)
}
// AllowNetbird allows netbird interface traffic
@@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
- if c.Table.Name == "filter" && c.Name == "INPUT" {
+ if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
@@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error {
return nil
}
+// SetLegacyManagement sets the route manager to use legacy management
+func (m *Manager) SetLegacyManagement(isLegacy bool) error {
+ oldLegacy := m.router.legacyManagement
+
+ if oldLegacy != isLegacy {
+ m.router.legacyManagement = isLegacy
+ log.Debugf("Set legacy management to %v", isLegacy)
+ }
+
+ // client reconnected to a newer mgmt, we need to cleanup the legacy rules
+ if !isLegacy && oldLegacy {
+ if err := m.router.RemoveAllLegacyRouteRules(); err != nil {
+ return fmt.Errorf("remove legacy routing rules: %v", err)
+ }
+
+ log.Debugf("Legacy routing rules removed")
+ }
+
+ return nil
+}
+
// Reset firewall to the default state
func (m *Manager) Reset() error {
m.mutex.Lock()
@@ -185,14 +230,16 @@ func (m *Manager) Reset() error {
}
}
- m.router.ResetForwardRules()
+ if err := m.router.Reset(); err != nil {
+ return fmt.Errorf("reset forward rules: %v", err)
+ }
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list of tables: %w", err)
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
@@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
m.rConn.DelTable(t)
}
}
- table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
+ table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
@@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
+ &expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(allowNetbirdInputRuleID),
}
@@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
- if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" {
+ if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
@@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
}
return nil
}
+
+func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
+ rule := &nftables.Rule{
+ Table: table,
+ Chain: chain,
+ Exprs: []expr.Any{
+ &expr.Ct{
+ Key: expr.CtKeySTATE,
+ Register: 1,
+ },
+ &expr.Bitwise{
+ SourceRegister: 1,
+ DestRegister: 1,
+ Len: 4,
+ Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
+ Xor: binaryutil.NativeEndian.PutUint32(0),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpNeq,
+ Register: 1,
+ Data: []byte{0, 0, 0, 0},
+ },
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
+ },
+ }
+
+ conn.InsertRule(rule)
+}
diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go
index 1f226e315..904050a51 100644
--- a/client/firewall/nftables/manager_linux_test.go
+++ b/client/firewall/nftables/manager_linux_test.go
@@ -9,14 +9,30 @@ import (
"time"
"github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
+var ifaceMock = &iFaceMock{
+ NameFunc: func() string {
+ return "lo"
+ },
+ AddressFunc: func() iface.WGAddress {
+ return iface.WGAddress{
+ IP: net.ParseIP("100.96.0.1"),
+ Network: &net.IPNet{
+ IP: net.ParseIP("100.96.0.0"),
+ Mask: net.IPv4Mask(255, 255, 255, 0),
+ },
+ }
+ },
+}
+
// iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct {
NameFunc func() string
@@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
- mock := &iFaceMock{
- NameFunc: func() string {
- return "lo"
- },
- AddressFunc: func() iface.WGAddress {
- return iface.WGAddress{
- IP: net.ParseIP("100.96.0.1"),
- Network: &net.IPNet{
- IP: net.ParseIP("100.96.0.0"),
- Mask: net.IPv4Mask(255, 255, 255, 0),
- },
- }
- },
- }
// just check on the local interface
- manager, err := Create(context.Background(), mock)
+ manager, err := Create(context.Background(), ifaceMock)
require.NoError(t, err)
time.Sleep(time.Second * 3)
@@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
- rule, err := manager.AddFiltering(
+ rule, err := manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
@@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
- require.Len(t, rules, 1, "expected 1 rules")
+ require.Len(t, rules, 2, "expected 2 rules")
+
+ expectedExprs1 := []expr.Any{
+ &expr.Ct{
+ Key: expr.CtKeySTATE,
+ Register: 1,
+ },
+ &expr.Bitwise{
+ SourceRegister: 1,
+ DestRegister: 1,
+ Len: 4,
+ Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
+ Xor: binaryutil.NativeEndian.PutUint32(0),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpNeq,
+ Register: 1,
+ Data: []byte{0, 0, 0, 0},
+ },
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
+ }
+ require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
- expectedExprs := []expr.Any{
- &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: ifname("lo"),
- },
+ expectedExprs2 := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
@@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) {
},
&expr.Verdict{Kind: expr.VerdictDrop},
}
- require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions")
+ require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
for _, r := range rule {
- err = manager.DeleteRule(r)
+ err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
}
@@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules")
- require.Len(t, rules, 0, "expected 0 rules after deletion")
+ // established rule remains
+ require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset()
require.NoError(t, err, "failed to reset")
@@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go
deleted file mode 100644
index 71d5ac88e..000000000
--- a/client/firewall/nftables/route_linux.go
+++ /dev/null
@@ -1,431 +0,0 @@
-package nftables
-
-import (
- "bytes"
- "context"
- "errors"
- "fmt"
- "net"
- "net/netip"
-
- "github.com/google/nftables"
- "github.com/google/nftables/binaryutil"
- "github.com/google/nftables/expr"
- log "github.com/sirupsen/logrus"
-
- "github.com/netbirdio/netbird/client/firewall/manager"
-)
-
-const (
- chainNameRouteingFw = "netbird-rt-fwd"
- chainNameRoutingNat = "netbird-rt-nat"
-
- userDataAcceptForwardRuleSrc = "frwacceptsrc"
- userDataAcceptForwardRuleDst = "frwacceptdst"
-
- loopbackInterface = "lo\x00"
-)
-
-// some presets for building nftable rules
-var (
- zeroXor = binaryutil.NativeEndian.PutUint32(0)
-
- exprCounterAccept = []expr.Any{
- &expr.Counter{},
- &expr.Verdict{
- Kind: expr.VerdictAccept,
- },
- }
-
- errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
-)
-
-type router struct {
- ctx context.Context
- stop context.CancelFunc
- conn *nftables.Conn
- workTable *nftables.Table
- filterTable *nftables.Table
- chains map[string]*nftables.Chain
- // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
- rules map[string]*nftables.Rule
- isDefaultFwdRulesEnabled bool
-}
-
-func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
- ctx, cancel := context.WithCancel(parentCtx)
-
- r := &router{
- ctx: ctx,
- stop: cancel,
- conn: &nftables.Conn{},
- workTable: workTable,
- chains: make(map[string]*nftables.Chain),
- rules: make(map[string]*nftables.Rule),
- }
-
- var err error
- r.filterTable, err = r.loadFilterTable()
- if err != nil {
- if errors.Is(err, errFilterTableNotFound) {
- log.Warnf("table 'filter' not found for forward rules")
- } else {
- return nil, err
- }
- }
-
- err = r.cleanUpDefaultForwardRules()
- if err != nil {
- log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
- }
-
- err = r.createContainers()
- if err != nil {
- log.Errorf("failed to create containers for route: %s", err)
- }
- return r, err
-}
-
-func (r *router) RouteingFwChainName() string {
- return chainNameRouteingFw
-}
-
-// ResetForwardRules cleans existing nftables default forward rules from the system
-func (r *router) ResetForwardRules() {
- err := r.cleanUpDefaultForwardRules()
- if err != nil {
- log.Errorf("failed to reset forward rules: %s", err)
- }
-}
-
-func (r *router) loadFilterTable() (*nftables.Table, error) {
- tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
- if err != nil {
- return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
- }
-
- for _, table := range tables {
- if table.Name == "filter" {
- return table, nil
- }
- }
-
- return nil, errFilterTableNotFound
-}
-
-func (r *router) createContainers() error {
-
- r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
- Name: chainNameRouteingFw,
- Table: r.workTable,
- })
-
- r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
- Name: chainNameRoutingNat,
- Table: r.workTable,
- Hooknum: nftables.ChainHookPostrouting,
- Priority: nftables.ChainPriorityNATSource - 1,
- Type: nftables.ChainTypeNAT,
- })
-
- // Add RETURN rule for loopback interface
- loRule := &nftables.Rule{
- Table: r.workTable,
- Chain: r.chains[chainNameRoutingNat],
- Exprs: []expr.Any{
- &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
- &expr.Cmp{
- Op: expr.CmpOpEq,
- Register: 1,
- Data: []byte(loopbackInterface),
- },
- &expr.Verdict{Kind: expr.VerdictReturn},
- },
- }
- r.conn.InsertRule(loRule)
-
- err := r.refreshRulesMap()
- if err != nil {
- log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
- }
-
- err = r.conn.Flush()
- if err != nil {
- return fmt.Errorf("nftables: unable to initialize table: %v", err)
- }
- return nil
-}
-
-// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
-func (r *router) AddRoutingRules(pair manager.RouterPair) error {
- err := r.refreshRulesMap()
- if err != nil {
- return err
- }
-
- err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
- if err != nil {
- return err
- }
- err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
- if err != nil {
- return err
- }
-
- if pair.Masquerade {
- err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
- if err != nil {
- return err
- }
- err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
- if err != nil {
- return err
- }
- }
-
- if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
- log.Debugf("add default accept forward rule")
- r.acceptForwardRule(pair.Source)
- }
-
- err = r.conn.Flush()
- if err != nil {
- return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
- }
- return nil
-}
-
-// addRoutingRule inserts a nftable rule to the conn client flush queue
-func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
- sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
- destExp := generateCIDRMatcherExpressions(false, pair.Destination)
-
- var expression []expr.Any
- if isNat {
- expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
- } else {
- expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
- }
-
- ruleKey := manager.GenKey(format, pair.ID)
-
- _, exists := r.rules[ruleKey]
- if exists {
- err := r.removeRoutingRule(format, pair)
- if err != nil {
- return err
- }
- }
-
- r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
- Table: r.workTable,
- Chain: r.chains[chainName],
- Exprs: expression,
- UserData: []byte(ruleKey),
- })
- return nil
-}
-
-func (r *router) acceptForwardRule(sourceNetwork string) {
- src := generateCIDRMatcherExpressions(true, sourceNetwork)
- dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
-
- var exprs []expr.Any
- exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
- Kind: expr.VerdictAccept,
- })...)
-
- rule := &nftables.Rule{
- Table: r.filterTable,
- Chain: &nftables.Chain{
- Name: "FORWARD",
- Table: r.filterTable,
- Type: nftables.ChainTypeFilter,
- Hooknum: nftables.ChainHookForward,
- Priority: nftables.ChainPriorityFilter,
- },
- Exprs: exprs,
- UserData: []byte(userDataAcceptForwardRuleSrc),
- }
-
- r.conn.AddRule(rule)
-
- src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
- dst = generateCIDRMatcherExpressions(false, sourceNetwork)
-
- exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
- Kind: expr.VerdictAccept,
- })...)
-
- rule = &nftables.Rule{
- Table: r.filterTable,
- Chain: &nftables.Chain{
- Name: "FORWARD",
- Table: r.filterTable,
- Type: nftables.ChainTypeFilter,
- Hooknum: nftables.ChainHookForward,
- Priority: nftables.ChainPriorityFilter,
- },
- Exprs: exprs,
- UserData: []byte(userDataAcceptForwardRuleDst),
- }
- r.conn.AddRule(rule)
- r.isDefaultFwdRulesEnabled = true
-}
-
-// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
-func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
- err := r.refreshRulesMap()
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.ForwardingFormat, pair)
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.NatFormat, pair)
- if err != nil {
- return err
- }
-
- err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
- if err != nil {
- return err
- }
-
- if len(r.rules) == 0 {
- err := r.cleanUpDefaultForwardRules()
- if err != nil {
- log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
- }
- }
-
- err = r.conn.Flush()
- if err != nil {
- return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
- }
- log.Debugf("nftables: removed rules for %s", pair.Destination)
- return nil
-}
-
-// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
-func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
- ruleKey := manager.GenKey(format, pair.ID)
-
- rule, found := r.rules[ruleKey]
- if found {
- ruleType := "forwarding"
- if rule.Chain.Type == nftables.ChainTypeNAT {
- ruleType = "nat"
- }
-
- err := r.conn.DelRule(rule)
- if err != nil {
- return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
- }
-
- log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
-
- delete(r.rules, ruleKey)
- }
- return nil
-}
-
-// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
-// duplicates and to get missing attributes that we don't have when adding new rules
-func (r *router) refreshRulesMap() error {
- for _, chain := range r.chains {
- rules, err := r.conn.GetRules(chain.Table, chain)
- if err != nil {
- return fmt.Errorf("nftables: unable to list rules: %v", err)
- }
- for _, rule := range rules {
- if len(rule.UserData) > 0 {
- r.rules[string(rule.UserData)] = rule
- }
- }
- }
- return nil
-}
-
-func (r *router) cleanUpDefaultForwardRules() error {
- if r.filterTable == nil {
- r.isDefaultFwdRulesEnabled = false
- return nil
- }
-
- chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
- if err != nil {
- return err
- }
-
- var rules []*nftables.Rule
- for _, chain := range chains {
- if chain.Table.Name != r.filterTable.Name {
- continue
- }
- if chain.Name != "FORWARD" {
- continue
- }
-
- rules, err = r.conn.GetRules(r.filterTable, chain)
- if err != nil {
- return err
- }
- }
-
- for _, rule := range rules {
- if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
- err := r.conn.DelRule(rule)
- if err != nil {
- return err
- }
- }
- }
- r.isDefaultFwdRulesEnabled = false
- return r.conn.Flush()
-}
-
-// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
-func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
- ip, network, _ := net.ParseCIDR(cidr)
- ipToAdd, _ := netip.AddrFromSlice(ip)
- add := ipToAdd.Unmap()
-
- var offSet uint32
- if source {
- offSet = 12 // src offset
- } else {
- offSet = 16 // dst offset
- }
-
- return []expr.Any{
- // fetch src add
- &expr.Payload{
- DestRegister: 1,
- Base: expr.PayloadBaseNetworkHeader,
- Offset: offSet,
- Len: 4,
- },
- // net mask
- &expr.Bitwise{
- DestRegister: 1,
- SourceRegister: 1,
- Len: 4,
- Mask: network.Mask,
- Xor: zeroXor,
- },
- // net address
- &expr.Cmp{
- Register: 1,
- Data: add.AsSlice(),
- },
- }
-}
diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go
new file mode 100644
index 000000000..aa61e1858
--- /dev/null
+++ b/client/firewall/nftables/router_linux.go
@@ -0,0 +1,798 @@
+package nftables
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "net"
+ "net/netip"
+ "strings"
+
+ "github.com/google/nftables"
+ "github.com/google/nftables/binaryutil"
+ "github.com/google/nftables/expr"
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+
+ nberrors "github.com/netbirdio/netbird/client/errors"
+ firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/internal/acl/id"
+ "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
+)
+
+const (
+ chainNameRoutingFw = "netbird-rt-fwd"
+ chainNameRoutingNat = "netbird-rt-nat"
+ chainNameForward = "FORWARD"
+
+ userDataAcceptForwardRuleIif = "frwacceptiif"
+ userDataAcceptForwardRuleOif = "frwacceptoif"
+)
+
+const refreshRulesMapError = "refresh rules map: %w"
+
+var (
+ errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
+)
+
+type router struct {
+ ctx context.Context
+ stop context.CancelFunc
+ conn *nftables.Conn
+ workTable *nftables.Table
+ filterTable *nftables.Table
+ chains map[string]*nftables.Chain
+ // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
+ rules map[string]*nftables.Rule
+ ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
+
+ wgIface iFaceMapper
+ legacyManagement bool
+}
+
+func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
+ ctx, cancel := context.WithCancel(parentCtx)
+
+ r := &router{
+ ctx: ctx,
+ stop: cancel,
+ conn: &nftables.Conn{},
+ workTable: workTable,
+ chains: make(map[string]*nftables.Chain),
+ rules: make(map[string]*nftables.Rule),
+ wgIface: wgIface,
+ }
+
+ r.ipsetCounter = refcounter.New(
+ r.createIpSet,
+ r.deleteIpSet,
+ )
+
+ var err error
+ r.filterTable, err = r.loadFilterTable()
+ if err != nil {
+ if errors.Is(err, errFilterTableNotFound) {
+ log.Warnf("table 'filter' not found for forward rules")
+ } else {
+ return nil, err
+ }
+ }
+
+ err = r.cleanUpDefaultForwardRules()
+ if err != nil {
+ log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
+ }
+
+ err = r.createContainers()
+ if err != nil {
+ log.Errorf("failed to create containers for route: %s", err)
+ }
+ return r, err
+}
+
+// Reset clears the ipset reference counter (keeping the sets themselves) and removes NetBird's default forward rules from the system filter table
+func (r *router) Reset() error {
+ // clear without deleting the ipsets, the nf table will be deleted by the caller
+ r.ipsetCounter.Clear()
+
+ return r.cleanUpDefaultForwardRules()
+}
+
+func (r *router) cleanUpDefaultForwardRules() error {
+ if r.filterTable == nil {
+ return nil
+ }
+
+ chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
+ if err != nil {
+ return fmt.Errorf("list chains: %v", err)
+ }
+
+ for _, chain := range chains {
+ if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
+ continue
+ }
+
+ rules, err := r.conn.GetRules(r.filterTable, chain)
+ if err != nil {
+ return fmt.Errorf("get rules: %v", err)
+ }
+
+ for _, rule := range rules {
+ if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
+ bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
+ if err := r.conn.DelRule(rule); err != nil {
+ return fmt.Errorf("delete rule: %v", err)
+ }
+ }
+ }
+ }
+
+ return r.conn.Flush()
+}
+
+func (r *router) loadFilterTable() (*nftables.Table, error) {
+ tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
+ if err != nil {
+ return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
+ }
+
+ for _, table := range tables {
+ if table.Name == "filter" {
+ return table, nil
+ }
+ }
+
+ return nil, errFilterTableNotFound
+}
+
+func (r *router) createContainers() error {
+
+ r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
+ Name: chainNameRoutingFw,
+ Table: r.workTable,
+ })
+
+ insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
+
+ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
+ Name: chainNameRoutingNat,
+ Table: r.workTable,
+ Hooknum: nftables.ChainHookPostrouting,
+ Priority: nftables.ChainPriorityNATSource - 1,
+ Type: nftables.ChainTypeNAT,
+ })
+
+ r.acceptForwardRules()
+
+ err := r.refreshRulesMap()
+ if err != nil {
+ log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
+ }
+
+ err = r.conn.Flush()
+ if err != nil {
+ return fmt.Errorf("nftables: unable to initialize table: %v", err)
+ }
+ return nil
+}
+
+// AddRouteFiltering appends a nftables rule to the routing chain
+func (r *router) AddRouteFiltering(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto firewall.Protocol,
+ sPort *firewall.Port,
+ dPort *firewall.Port,
+ action firewall.Action,
+) (firewall.Rule, error) {
+ ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
+ if _, ok := r.rules[string(ruleKey)]; ok {
+ return ruleKey, nil
+ }
+
+ chain := r.chains[chainNameRoutingFw]
+ var exprs []expr.Any
+
+ switch {
+ case len(sources) == 1 && sources[0].Bits() == 0:
+ // If it's 0.0.0.0/0, we don't need to add any source matching
+ case len(sources) == 1:
+ // If there's only one source, we can use it directly
+ exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
+ default:
+ // If there are multiple sources, create or get an ipset
+ var err error
+ exprs, err = r.getIpSetExprs(sources, exprs)
+ if err != nil {
+ return nil, fmt.Errorf("get ipset expressions: %w", err)
+ }
+ }
+
+ // Handle destination
+ exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
+
+ // Handle protocol
+ if proto != firewall.ProtocolALL {
+ protoNum, err := protoToInt(proto)
+ if err != nil {
+ return nil, fmt.Errorf("convert protocol to number: %w", err)
+ }
+ exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
+ exprs = append(exprs, &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: []byte{protoNum},
+ })
+
+ exprs = append(exprs, applyPort(sPort, true)...)
+ exprs = append(exprs, applyPort(dPort, false)...)
+ }
+
+ exprs = append(exprs, &expr.Counter{})
+
+ var verdict expr.VerdictKind
+ if action == firewall.ActionAccept {
+ verdict = expr.VerdictAccept
+ } else {
+ verdict = expr.VerdictDrop
+ }
+ exprs = append(exprs, &expr.Verdict{Kind: verdict})
+
+ rule := &nftables.Rule{
+ Table: r.workTable,
+ Chain: chain,
+ Exprs: exprs,
+ UserData: []byte(ruleKey),
+ }
+
+ r.rules[string(ruleKey)] = r.conn.AddRule(rule)
+
+ return ruleKey, r.conn.Flush()
+}
+
+func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
+ setName := firewall.GenerateSetName(sources)
+ ref, err := r.ipsetCounter.Increment(setName, sources)
+ if err != nil {
+ return nil, fmt.Errorf("create or get ipset for sources: %w", err)
+ }
+
+ exprs = append(exprs,
+ &expr.Payload{
+ DestRegister: 1,
+ Base: expr.PayloadBaseNetworkHeader,
+ Offset: 12,
+ Len: 4,
+ },
+ &expr.Lookup{
+ SourceRegister: 1,
+ SetName: ref.Out.Name,
+ SetID: ref.Out.ID,
+ },
+ )
+ return exprs, nil
+}
+
+func (r *router) DeleteRouteRule(rule firewall.Rule) error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ ruleKey := rule.GetRuleID()
+ nftRule, exists := r.rules[ruleKey]
+ if !exists {
+ log.Debugf("route rule %s not found", ruleKey)
+ return nil
+ }
+
+ setName := r.findSetNameInRule(nftRule)
+
+ if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
+ return fmt.Errorf("delete: %w", err)
+ }
+
+ if setName != "" {
+ if _, err := r.ipsetCounter.Decrement(setName); err != nil {
+ return fmt.Errorf("decrement ipset reference: %w", err)
+ }
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
+ }
+
+ return nil
+}
+
+func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
+ // overlapping prefixes will result in an error, so we need to merge them
+ sources = firewall.MergeIPRanges(sources)
+
+ set := &nftables.Set{
+ Name: setName,
+ Table: r.workTable,
+ // required for prefixes
+ Interval: true,
+ KeyType: nftables.TypeIPAddr,
+ }
+
+ var elements []nftables.SetElement
+ for _, prefix := range sources {
+ // TODO: Implement IPv6 support
+ if prefix.Addr().Is6() {
+ log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
+ continue
+ }
+
+ // nftables needs half-open intervals [firstIP, lastIP) for prefixes
+ // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
+ firstIP := prefix.Addr()
+ lastIP := calculateLastIP(prefix).Next()
+
+ elements = append(elements,
+ // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
+ // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
+ nftables.SetElement{Key: firstIP.AsSlice()},
+ nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
+ )
+ }
+
+ if err := r.conn.AddSet(set, elements); err != nil {
+ return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return nil, fmt.Errorf("flush error: %w", err)
+ }
+
+ log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
+
+ return set, nil
+}
+
+// calculateLastIP determines the last IP address in a given IPv4 prefix; IPv6 prefixes are not supported (the uint32 math assumes a 4-byte address).
+func calculateLastIP(prefix netip.Prefix) netip.Addr {
+ hostMask := ^uint32(0) >> prefix.Masked().Bits()
+ lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
+
+ return netip.AddrFrom4(uint32ToBytes(lastIP))
+}
+
+// uint32FromNetipAddr converts an IPv4 netip.Addr to its big-endian uint32 representation.
+func uint32FromNetipAddr(addr netip.Addr) uint32 {
+ b := addr.As4()
+ return binary.BigEndian.Uint32(b[:])
+}
+
+// uint32ToBytes converts a uint32 back to a big-endian 4-byte array suitable for netip.AddrFrom4.
+func uint32ToBytes(ip uint32) [4]byte {
+ var b [4]byte
+ binary.BigEndian.PutUint32(b[:], ip)
+ return b
+}
+
+func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
+ r.conn.DelSet(set)
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf(flushError, err)
+ }
+
+ log.Debugf("Deleted unused ipset %s", setName)
+ return nil
+}
+
+func (r *router) findSetNameInRule(rule *nftables.Rule) string {
+ for _, e := range rule.Exprs {
+ if lookup, ok := e.(*expr.Lookup); ok {
+ return lookup.SetName
+ }
+ }
+ return ""
+}
+
+func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
+ if err := r.conn.DelRule(rule); err != nil {
+ return fmt.Errorf("delete rule %s: %w", ruleKey, err)
+ }
+ delete(r.rules, ruleKey)
+
+ log.Debugf("removed route rule %s", ruleKey)
+
+ return nil
+}
+
+// AddNatRule appends a nftables rule pair to the nat chain
+func (r *router) AddNatRule(pair firewall.RouterPair) error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ if r.legacyManagement {
+ log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
+ if err := r.addLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("add legacy routing rule: %w", err)
+ }
+ }
+
+ if pair.Masquerade {
+ if err := r.addNatRule(pair); err != nil {
+ return fmt.Errorf("add nat rule: %w", err)
+ }
+
+ if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("add inverse nat rule: %w", err)
+ }
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
+ }
+
+ return nil
+}
+
+// addNatRule queues a nftables masquerade rule on the conn client flush queue, replacing any existing rule with the same key
+func (r *router) addNatRule(pair firewall.RouterPair) error {
+ sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
+ destExp := generateCIDRMatcherExpressions(false, pair.Destination)
+
+ dir := expr.MetaKeyIIFNAME
+ if pair.Inverse {
+ dir = expr.MetaKeyOIFNAME
+ }
+
+ intf := ifname(r.wgIface.Name())
+ exprs := []expr.Any{
+ &expr.Meta{
+ Key: dir,
+ Register: 1,
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: intf,
+ },
+ }
+
+ exprs = append(exprs, sourceExp...)
+ exprs = append(exprs, destExp...)
+ exprs = append(exprs,
+ &expr.Counter{}, &expr.Masq{},
+ )
+
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
+
+ if _, exists := r.rules[ruleKey]; exists {
+ if err := r.removeNatRule(pair); err != nil {
+ return fmt.Errorf("remove routing rule: %w", err)
+ }
+ }
+
+ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
+ Table: r.workTable,
+ Chain: r.chains[chainNameRoutingNat],
+ Exprs: exprs,
+ UserData: []byte(ruleKey),
+ })
+ return nil
+}
+
+// addLegacyRouteRule adds a legacy accept routing rule for management servers that predate route ACLs
+func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
+ sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
+ destExp := generateCIDRMatcherExpressions(false, pair.Destination)
+
+ exprs := []expr.Any{
+ &expr.Counter{},
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
+ }
+
+ expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
+
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if _, exists := r.rules[ruleKey]; exists {
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("remove legacy routing rule: %w", err)
+ }
+ }
+
+ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
+ Table: r.workTable,
+ Chain: r.chains[chainNameRoutingFw],
+ Exprs: expression,
+ UserData: []byte(ruleKey),
+ })
+ return nil
+}
+
+// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
+func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ if err := r.conn.DelRule(rule); err != nil {
+ return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
+ }
+
+ log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
+
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
+ }
+
+ return nil
+}
+
+// GetLegacyManagement returns the route manager's legacy management mode
+func (r *router) GetLegacyManagement() bool {
+ return r.legacyManagement
+}
+
+// SetLegacyManagement sets the route manager to use legacy management mode
+func (r *router) SetLegacyManagement(isLegacy bool) {
+ r.legacyManagement = isLegacy
+}
+
+// RemoveAllLegacyRouteRules queues removal of all legacy routing rules (pre-route-ACL management servers); note the deletions are not flushed here
+func (r *router) RemoveAllLegacyRouteRules() error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ var merr *multierror.Error
+ for k, rule := range r.rules {
+ if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
+ continue
+ }
+ if err := r.conn.DelRule(rule); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
+ }
+ }
+ return nberrors.FormatErrorOrNil(merr)
+}
+
+// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
+// that our traffic is not dropped by existing rules there.
+// The existing FORWARD rules/policies decide outbound traffic towards our interface.
+// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
+func (r *router) acceptForwardRules() {
+ if r.filterTable == nil {
+ log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
+ return
+ }
+
+ intf := ifname(r.wgIface.Name())
+
+ // Rule for incoming interface (iif) with counter
+ iifRule := &nftables.Rule{
+ Table: r.filterTable,
+ Chain: &nftables.Chain{
+ Name: "FORWARD",
+ Table: r.filterTable,
+ Type: nftables.ChainTypeFilter,
+ Hooknum: nftables.ChainHookForward,
+ Priority: nftables.ChainPriorityFilter,
+ },
+ Exprs: []expr.Any{
+ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: intf,
+ },
+ &expr.Counter{},
+ &expr.Verdict{Kind: expr.VerdictAccept},
+ },
+ UserData: []byte(userDataAcceptForwardRuleIif),
+ }
+ r.conn.InsertRule(iifRule)
+
+ // Rule for outgoing interface (oif) with counter
+ oifRule := &nftables.Rule{
+ Table: r.filterTable,
+ Chain: &nftables.Chain{
+ Name: "FORWARD",
+ Table: r.filterTable,
+ Type: nftables.ChainTypeFilter,
+ Hooknum: nftables.ChainHookForward,
+ Priority: nftables.ChainPriorityFilter,
+ },
+ Exprs: []expr.Any{
+ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: intf,
+ },
+ &expr.Ct{
+ Key: expr.CtKeySTATE,
+ Register: 2,
+ },
+ &expr.Bitwise{
+ SourceRegister: 2,
+ DestRegister: 2,
+ Len: 4,
+ Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
+ Xor: binaryutil.NativeEndian.PutUint32(0),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpNeq,
+ Register: 2,
+ Data: []byte{0, 0, 0, 0},
+ },
+ &expr.Counter{},
+ &expr.Verdict{Kind: expr.VerdictAccept},
+ },
+ UserData: []byte(userDataAcceptForwardRuleOif),
+ }
+
+ r.conn.InsertRule(oifRule)
+}
+
+// RemoveNatRule removes a nftables rule pair from nat chains
+func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
+ if err := r.refreshRulesMap(); err != nil {
+ return fmt.Errorf(refreshRulesMapError, err)
+ }
+
+ if err := r.removeNatRule(pair); err != nil {
+ return fmt.Errorf("remove nat rule: %w", err)
+ }
+
+ if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
+ return fmt.Errorf("remove inverse nat rule: %w", err)
+ }
+
+ if err := r.removeLegacyRouteRule(pair); err != nil {
+ return fmt.Errorf("remove legacy routing rule: %w", err)
+ }
+
+ if err := r.conn.Flush(); err != nil {
+ return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
+ }
+
+ log.Debugf("nftables: removed rules for %s", pair.Destination)
+ return nil
+}
+
+// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
+func (r *router) removeNatRule(pair firewall.RouterPair) error {
+ ruleKey := firewall.GenKey(firewall.NatFormat, pair)
+
+ if rule, exists := r.rules[ruleKey]; exists {
+ err := r.conn.DelRule(rule)
+ if err != nil {
+ return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err)
+ }
+
+ log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination)
+
+ delete(r.rules, ruleKey)
+ } else {
+ log.Debugf("nftables: nat rule %s not found", ruleKey)
+ }
+
+ return nil
+}
+
+// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
+// duplicates and to get missing attributes that we don't have when adding new rules
+func (r *router) refreshRulesMap() error {
+ for _, chain := range r.chains {
+ rules, err := r.conn.GetRules(chain.Table, chain)
+ if err != nil {
+ return fmt.Errorf("nftables: unable to list rules: %v", err)
+ }
+ for _, rule := range rules {
+ if len(rule.UserData) > 0 {
+ r.rules[string(rule.UserData)] = rule
+ }
+ }
+ }
+ return nil
+}
+
+// generateCIDRMatcherExpressions generates nftables expressions that match an IPv4 CIDR as either the source or the destination address
+func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
+ var offset uint32
+ if source {
+ offset = 12 // src offset
+ } else {
+ offset = 16 // dst offset
+ }
+
+ ones := prefix.Bits()
+ // 0.0.0.0/0 doesn't need extra expressions
+ if ones == 0 {
+ return nil
+ }
+
+ mask := net.CIDRMask(ones, 32)
+
+ return []expr.Any{
+ &expr.Payload{
+ DestRegister: 1,
+ Base: expr.PayloadBaseNetworkHeader,
+ Offset: offset,
+ Len: 4,
+ },
+ // netmask
+ &expr.Bitwise{
+ DestRegister: 1,
+ SourceRegister: 1,
+ Len: 4,
+ Mask: mask,
+ Xor: []byte{0, 0, 0, 0},
+ },
+ // net address
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: prefix.Masked().Addr().AsSlice(),
+ },
+ }
+}
+
+func applyPort(port *firewall.Port, isSource bool) []expr.Any {
+ if port == nil {
+ return nil
+ }
+
+ var exprs []expr.Any
+
+ offset := uint32(2) // Default offset for destination port
+ if isSource {
+ offset = 0 // Offset for source port
+ }
+
+ exprs = append(exprs, &expr.Payload{
+ DestRegister: 1,
+ Base: expr.PayloadBaseTransportHeader,
+ Offset: offset,
+ Len: 2,
+ })
+
+ if port.IsRange && len(port.Values) == 2 {
+ // Handle port range
+ exprs = append(exprs,
+ &expr.Cmp{
+ Op: expr.CmpOpGte,
+ Register: 1,
+ Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
+ },
+ &expr.Cmp{
+ Op: expr.CmpOpLte,
+ Register: 1,
+ Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
+ },
+ )
+ } else {
+ // Handle single port or multiple ports
+ for i, p := range port.Values {
+ if i > 0 {
+ // Add a bitwise OR operation between port checks
+ exprs = append(exprs, &expr.Bitwise{
+ SourceRegister: 1,
+ DestRegister: 1,
+ Len: 4,
+ Mask: []byte{0x00, 0x00, 0xff, 0xff},
+ Xor: []byte{0x00, 0x00, 0x00, 0x00},
+ })
+ }
+ exprs = append(exprs, &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: binaryutil.BigEndian.PutUint16(uint16(p)),
+ })
+ }
+ }
+
+ return exprs
+}
diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go
index 913fbd5d2..bbf92f3be 100644
--- a/client/firewall/nftables/router_linux_test.go
+++ b/client/firewall/nftables/router_linux_test.go
@@ -4,11 +4,15 @@ package nftables
import (
"context"
+ "encoding/binary"
+ "net/netip"
+ "os/exec"
"testing"
"github.com/coreos/go-iptables/iptables"
"github.com/google/nftables"
"github.com/google/nftables/expr"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -24,56 +28,50 @@ const (
NFTABLES
)
-func TestNftablesManager_InsertRoutingRules(t *testing.T) {
+func TestNftablesManager_AddNatRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
- manager, err := newRouter(context.TODO(), table)
+ manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
- defer manager.ResetForwardRules()
+ defer func(manager *router) {
+ require.NoError(t, manager.Reset(), "failed to reset rules")
+ }(manager)
require.NoError(t, err, "shouldn't return error")
- err = manager.AddRoutingRules(testCase.InputPair)
- defer func() {
- _ = manager.RemoveRoutingRules(testCase.InputPair)
- }()
- require.NoError(t, err, "forwarding pair should be inserted")
+ err = manager.AddNatRule(testCase.InputPair)
+ require.NoError(t, err, "pair should be inserted")
- sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
- destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
- testingExpression := append(sourceExp, destExp...) //nolint:gocritic
- fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
-
- found := 0
- for _, chain := range manager.chains {
- rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
- require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
- for _, rule := range rules {
- if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
- require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
- found = 1
- }
- }
- }
-
- require.Equal(t, 1, found, "should find at least 1 rule to test")
+ defer func(manager *router, pair firewall.RouterPair) {
+ require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule")
+ }(manager, testCase.InputPair)
if testCase.InputPair.Masquerade {
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
+ sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
+ destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
+ testingExpression := append(sourceExp, destExp...) //nolint:gocritic
+ testingExpression = append(testingExpression,
+ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: ifname(ifaceMock.Name()),
+ },
+ )
+
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
- sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
- destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
- testingExpression = append(sourceExp, destExp...) //nolint:gocritic
- inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
-
- found = 0
- for _, chain := range manager.chains {
- rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
- require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
- for _, rule := range rules {
- if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
- require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
- found = 1
- }
- }
- }
-
- require.Equal(t, 1, found, "should find at least 1 rule to test")
-
if testCase.InputPair.Masquerade {
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
+ sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
+ destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
+ testingExpression := append(sourceExp, destExp...) //nolint:gocritic
+ testingExpression = append(testingExpression,
+ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
+ &expr.Cmp{
+ Op: expr.CmpOpEq,
+ Register: 1,
+ Data: ifname(ifaceMock.Name()),
+ },
+ )
+
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
@@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) {
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
+
})
}
}
-func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
+func TestNftablesManager_RemoveNatRule(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this OS")
}
table, err := createWorkTable()
- if err != nil {
- t.Fatal(err)
- }
+ require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
- manager, err := newRouter(context.TODO(), table)
+ manager, err := newRouter(context.TODO(), table, ifaceMock)
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{}
- defer manager.ResetForwardRules()
+ defer func(manager *router) {
+ require.NoError(t, manager.Reset(), "failed to reset rules")
+ }(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
- forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
- forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
- insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
- Table: manager.workTable,
- Chain: manager.chains[chainNameRouteingFw],
- Exprs: forwardExp,
- UserData: []byte(forwardRuleKey),
- })
-
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
- natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
+ natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
@@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
UserData: []byte(natRuleKey),
})
- sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
- destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
-
- forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic
- inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
- insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
- Table: manager.workTable,
- Chain: manager.chains[chainNameRouteingFw],
- Exprs: forwardExp,
- UserData: []byte(inForwardRuleKey),
- })
+ sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source)
+ destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
- inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
+ inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
@@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
err = nftablesTestingClient.Flush()
require.NoError(t, err, "shouldn't return error")
- manager.ResetForwardRules()
+ err = manager.Reset()
+ require.NoError(t, err, "shouldn't return error")
- err = manager.RemoveRoutingRules(testCase.InputPair)
+ err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
@@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
- require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
- require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
}
@@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) {
}
}
+func TestRouter_AddRouteFiltering(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ workTable, err := createWorkTable()
+ require.NoError(t, err, "Failed to create work table")
+
+ defer deleteWorkTable()
+
+ r, err := newRouter(context.Background(), workTable, ifaceMock)
+ require.NoError(t, err, "Failed to create router")
+
+ defer func(r *router) {
+ require.NoError(t, r.Reset(), "Failed to reset rules")
+ }(r)
+
+ tests := []struct {
+ name string
+ sources []netip.Prefix
+ destination netip.Prefix
+ proto firewall.Protocol
+ sPort *firewall.Port
+ dPort *firewall.Port
+ direction firewall.RuleDirection
+ action firewall.Action
+ expectSet bool
+ }{
+ {
+ name: "Basic TCP rule with single source",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolTCP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{80}},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with multiple sources",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("172.16.0.0/16"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolUDP,
+ sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionDrop,
+ expectSet: true,
+ },
+ {
+ name: "All protocols rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
+ destination: netip.MustParsePrefix("0.0.0.0/0"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "ICMP rule",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
+ destination: netip.MustParsePrefix("10.0.0.0/8"),
+ proto: firewall.ProtocolICMP,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with multiple source ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
+ destination: netip.MustParsePrefix("192.168.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{80, 443, 8080}},
+ dPort: nil,
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "UDP rule with single IP and port range",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
+ destination: netip.MustParsePrefix("10.0.0.0/24"),
+ proto: firewall.ProtocolUDP,
+ sPort: nil,
+ dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ {
+ name: "TCP rule with source and destination ports",
+ sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
+ destination: netip.MustParsePrefix("172.16.0.0/16"),
+ proto: firewall.ProtocolTCP,
+ sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
+ dPort: &firewall.Port{Values: []int{22}},
+ direction: firewall.RuleDirectionOUT,
+ action: firewall.ActionAccept,
+ expectSet: false,
+ },
+ {
+ name: "Drop all incoming traffic",
+ sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
+ destination: netip.MustParsePrefix("192.168.0.0/24"),
+ proto: firewall.ProtocolALL,
+ sPort: nil,
+ dPort: nil,
+ direction: firewall.RuleDirectionIN,
+ action: firewall.ActionDrop,
+ expectSet: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
+ require.NoError(t, err, "AddRouteFiltering failed")
+
+ // Check if the rule is in the internal map
+ rule, ok := r.rules[ruleKey.GetRuleID()]
+ assert.True(t, ok, "Rule not found in internal map")
+
+ t.Log("Internal rule expressions:")
+ for i, expr := range rule.Exprs {
+ t.Logf(" [%d] %T: %+v", i, expr, expr)
+ }
+
+ // Verify internal rule content
+ verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
+
+ // Check if the rule exists in nftables and verify its content
+ rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
+ require.NoError(t, err, "Failed to get rules from nftables")
+
+ var nftRule *nftables.Rule
+ for _, rule := range rules {
+ if string(rule.UserData) == ruleKey.GetRuleID() {
+ nftRule = rule
+ break
+ }
+ }
+
+ require.NotNil(t, nftRule, "Rule not found in nftables")
+ t.Log("Actual nftables rule expressions:")
+ for i, expr := range nftRule.Exprs {
+ t.Logf(" [%d] %T: %+v", i, expr, expr)
+ }
+
+ // Verify actual nftables rule content
+ verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
+
+ // Clean up
+ err = r.DeleteRouteRule(ruleKey)
+ require.NoError(t, err, "Failed to delete rule")
+ })
+ }
+}
+
+func TestNftablesCreateIpSet(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ workTable, err := createWorkTable()
+ require.NoError(t, err, "Failed to create work table")
+
+ defer deleteWorkTable()
+
+ r, err := newRouter(context.Background(), workTable, ifaceMock)
+ require.NoError(t, err, "Failed to create router")
+
+ defer func() {
+ require.NoError(t, r.Reset(), "Failed to reset router")
+ }()
+
+ tests := []struct {
+ name string
+ sources []netip.Prefix
+ expected []netip.Prefix
+ }{
+ {
+ name: "Single IP",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
+ },
+ {
+ name: "Multiple IPs",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.1/32"),
+ netip.MustParsePrefix("10.0.0.1/32"),
+ netip.MustParsePrefix("172.16.0.1/32"),
+ },
+ },
+ {
+ name: "Single Subnet",
+ sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
+ },
+ {
+ name: "Multiple Subnets with Various Prefix Lengths",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("172.16.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("203.0.113.0/26"),
+ },
+ },
+ {
+ name: "Mix of Single IPs and Subnets in Different Positions",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("192.168.1.1/32"),
+ netip.MustParsePrefix("10.0.0.0/16"),
+ netip.MustParsePrefix("172.16.0.1/32"),
+ netip.MustParsePrefix("203.0.113.0/24"),
+ },
+ },
+ {
+ name: "Overlapping IPs/Subnets",
+ sources: []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("10.0.0.0/16"),
+ netip.MustParsePrefix("10.0.0.1/32"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ netip.MustParsePrefix("192.168.1.0/24"),
+ netip.MustParsePrefix("192.168.1.1/32"),
+ },
+ expected: []netip.Prefix{
+ netip.MustParsePrefix("10.0.0.0/8"),
+ netip.MustParsePrefix("192.168.0.0/16"),
+ },
+ },
+ }
+
+	// printNftSets logs the current nftables sets to aid debugging on failure.
+ printNftSets := func() {
+ cmd := exec.Command("nft", "list", "sets")
+ output, err := cmd.CombinedOutput()
+ if err != nil {
+ t.Logf("Failed to run 'nft list sets': %v", err)
+ } else {
+ t.Logf("Current nft sets:\n%s", output)
+ }
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ setName := firewall.GenerateSetName(tt.sources)
+ set, err := r.createIpSet(setName, tt.sources)
+ if err != nil {
+ t.Logf("Failed to create IP set: %v", err)
+ printNftSets()
+ require.NoError(t, err, "Failed to create IP set")
+ }
+ require.NotNil(t, set, "Created set is nil")
+
+ // Verify set properties
+ assert.Equal(t, setName, set.Name, "Set name mismatch")
+ assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
+ assert.True(t, set.Interval, "Set interval property should be true")
+ assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
+
+ // Fetch the created set from nftables
+ fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
+ require.NoError(t, err, "Failed to fetch created set")
+ require.NotNil(t, fetchedSet, "Fetched set is nil")
+
+ // Verify set elements
+ elements, err := r.conn.GetSetElements(fetchedSet)
+ require.NoError(t, err, "Failed to get set elements")
+
+ // Count the number of unique prefixes (excluding interval end markers)
+ uniquePrefixes := make(map[string]bool)
+ for _, elem := range elements {
+ if !elem.IntervalEnd {
+ ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
+ uniquePrefixes[ip.String()] = true
+ }
+ }
+
+ // Check against expected merged prefixes
+ expectedCount := len(tt.expected)
+ if expectedCount == 0 {
+ expectedCount = len(tt.sources)
+ }
+ assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
+
+ // Verify each expected prefix is in the set
+ for _, expected := range tt.expected {
+ found := false
+ for _, elem := range elements {
+ if !elem.IntervalEnd {
+ ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
+ if expected.Contains(ip) {
+ found = true
+ break
+ }
+ }
+ }
+ assert.True(t, found, "Expected prefix %s not found in set", expected)
+ }
+
+			r.conn.DelSet(set)
+			err = r.conn.Flush()
+			if err != nil {
+				printNftSets()
+			}
+			require.NoError(t, err, "Failed to delete set")
+ })
+ }
+}
+
+func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
+ t.Helper()
+
+ assert.NotNil(t, rule, "Rule should not be nil")
+
+ // Verify sources and destination
+ if expectSet {
+ assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
+ } else if len(sources) == 1 && sources[0].Bits() != 0 {
+ if direction == firewall.RuleDirectionIN {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
+ } else {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
+ }
+ }
+
+ if direction == firewall.RuleDirectionIN {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
+ } else {
+ assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
+ }
+
+ // Verify protocol
+ if proto != firewall.ProtocolALL {
+ assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
+ }
+
+ // Verify ports
+ if sPort != nil {
+ assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
+ }
+ if dPort != nil {
+ assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
+ }
+
+ // Verify action
+ assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
+}
+
+func containsSetLookup(exprs []expr.Any) bool {
+ for _, e := range exprs {
+ if _, ok := e.(*expr.Lookup); ok {
+ return true
+ }
+ }
+ return false
+}
+
+func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
+ var offset uint32
+ if isSource {
+ offset = 12 // src offset
+ } else {
+ offset = 16 // dst offset
+ }
+
+ var payloadFound, bitwiseFound, cmpFound bool
+ for _, e := range exprs {
+ switch ex := e.(type) {
+ case *expr.Payload:
+ if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
+ payloadFound = true
+ }
+ case *expr.Bitwise:
+ if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
+ bitwiseFound = true
+ }
+ case *expr.Cmp:
+ if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
+ cmpFound = true
+ }
+ }
+ }
+ return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
+}
+
+func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
+ var offset uint32 = 2 // Default offset for destination port
+ if isSource {
+ offset = 0 // Offset for source port
+ }
+
+ var payloadFound, portMatchFound bool
+ for _, e := range exprs {
+ switch ex := e.(type) {
+ case *expr.Payload:
+ if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
+ payloadFound = true
+ }
+ case *expr.Cmp:
+ if port.IsRange {
+ if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
+ portMatchFound = true
+ }
+ } else {
+ if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
+ portValue := binary.BigEndian.Uint16(ex.Data)
+ for _, p := range port.Values {
+ if uint16(p) == portValue {
+ portMatchFound = true
+ break
+ }
+ }
+ }
+ }
+ }
+ if payloadFound && portMatchFound {
+ return true
+ }
+ }
+ return false
+}
+
+func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
+ var metaFound, cmpFound bool
+ expectedProto, _ := protoToInt(proto)
+ for _, e := range exprs {
+ switch ex := e.(type) {
+ case *expr.Meta:
+ if ex.Key == expr.MetaKeyL4PROTO {
+ metaFound = true
+ }
+ case *expr.Cmp:
+ if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
+ cmpFound = true
+ }
+ }
+ }
+ return metaFound && cmpFound
+}
+
+func containsAction(exprs []expr.Any, action firewall.Action) bool {
+ for _, e := range exprs {
+ if verdict, ok := e.(*expr.Verdict); ok {
+ switch action {
+ case firewall.ActionAccept:
+ return verdict.Kind == expr.VerdictAccept
+ case firewall.ActionDrop:
+ return verdict.Kind == expr.VerdictDrop
+ }
+ }
+ }
+ return false
+}
+
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
func check() int {
nf := nftables.Conn{}
@@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) {
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
- table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4})
+ table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = sConn.Flush()
return table, err
@@ -273,7 +708,7 @@ func deleteWorkTable() {
}
for _, t := range tables {
- if t.Name == tableName {
+ if t.Name == tableNameNetbird {
sConn.DelTable(t)
}
}
diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go
index 432d113dd..267e93efd 100644
--- a/client/firewall/test/cases_linux.go
+++ b/client/firewall/test/cases_linux.go
@@ -1,8 +1,10 @@
-//go:build !android
-
package test
-import firewall "github.com/netbirdio/netbird/client/firewall/manager"
+import (
+ "net/netip"
+
+ firewall "github.com/netbirdio/netbird/client/firewall/manager"
+)
var (
InsertRuleTestCases = []struct {
@@ -13,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{
ID: "zxa",
- Source: "100.100.100.1/32",
- Destination: "100.100.200.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: false,
},
},
@@ -22,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
- Source: "100.100.100.1/32",
- Destination: "100.100.200.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},
@@ -38,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{
ID: "zxa",
- Source: "100.100.100.1/32",
- Destination: "100.100.200.0/24",
+ Source: netip.MustParsePrefix("100.100.100.1/32"),
+ Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true,
},
},
diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go
index 75792e9c0..0e3ee9799 100644
--- a/client/firewall/uspfilter/uspfilter.go
+++ b/client/firewall/uspfilter/uspfilter.go
@@ -3,6 +3,7 @@ package uspfilter
import (
"fmt"
"net"
+ "net/netip"
"sync"
"github.com/google/gopacket"
@@ -11,7 +12,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
const layerTypeAll = 0
@@ -22,7 +24,7 @@ var (
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
- SetFilter(iface.PacketFilter) error
+ SetFilter(device.PacketFilter) error
Address() iface.WGAddress
}
@@ -103,26 +105,26 @@ func (m *Manager) IsServerRouteSupported() bool {
}
}
-func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error {
+func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
- return m.nativeFirewall.InsertRoutingRules(pair)
+ return m.nativeFirewall.AddNatRule(pair)
}
-// RemoveRoutingRules removes a routing firewall rule
-func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error {
+// RemoveNatRule removes a routing firewall rule
+func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
- return m.nativeFirewall.RemoveRoutingRules(pair)
+ return m.nativeFirewall.RemoveNatRule(pair)
}
-// AddFiltering rule to the firewall
+// AddPeerFiltering rule to the firewall
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
-func (m *Manager) AddFiltering(
+func (m *Manager) AddPeerFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
@@ -188,8 +190,22 @@ func (m *Manager) AddFiltering(
return []firewall.Rule{&r}, nil
}
-// DeleteRule from the firewall by rule definition
-func (m *Manager) DeleteRule(rule firewall.Rule) error {
+func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
+ if m.nativeFirewall == nil {
+ return nil, errRouteNotSupported
+ }
+ return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
+}
+
+func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
+ if m.nativeFirewall == nil {
+ return errRouteNotSupported
+ }
+ return m.nativeFirewall.DeleteRouteRule(rule)
+}
+
+// DeletePeerRule from the firewall by rule definition
+func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
@@ -215,6 +231,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil
}
+// SetLegacyManagement doesn't need to be implemented for this manager
+func (m *Manager) SetLegacyManagement(_ bool) error {
+ return nil
+}
+
// Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil }
@@ -395,7 +416,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr {
if r.id == hookID {
rule := r
- return m.DeleteRule(&rule)
+ return m.DeletePeerRule(&rule)
}
}
}
@@ -403,7 +424,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr {
if r.id == hookID {
rule := r
- return m.DeleteRule(&rule)
+ return m.DeletePeerRule(&rule)
}
}
}
diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go
index 514a90539..c188deea4 100644
--- a/client/firewall/uspfilter/uspfilter_test.go
+++ b/client/firewall/uspfilter/uspfilter_test.go
@@ -11,15 +11,16 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
type IFaceMock struct {
- SetFilterFunc func(iface.PacketFilter) error
+ SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress
}
-func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error {
+func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented")
}
@@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -49,10 +50,10 @@ func TestManagerCreate(t *testing.T) {
}
}
-func TestManagerAddFiltering(t *testing.T) {
+func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error {
+ SetFilterFunc: func(device.PacketFilter) error {
isSetFilterCalled = true
return nil
},
@@ -71,7 +72,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
- rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -90,7 +91,7 @@ func TestManagerAddFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -106,7 +107,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
- rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -119,14 +120,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop
comment = "Test rule 2"
- rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
- err = m.DeleteRule(r)
+ err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
@@ -140,7 +141,7 @@ func TestManagerDeleteRule(t *testing.T) {
}
for _, r := range rule2 {
- err = m.DeleteRule(r)
+ err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
@@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -252,7 +253,7 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop
comment := "Test rule"
- _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment)
+ _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock)
@@ -290,7 +291,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept
comment := "Test rule"
- _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment)
+ _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) {
// creating mock iface
iface := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
// creating manager instance
@@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
ifaceMock := &IFaceMock{
- SetFilterFunc: func(iface.PacketFilter) error { return nil },
+ SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock)
require.NoError(t, err)
@@ -406,9 +407,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
- _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
+ _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
require.NoError(t, err, "failed to add rule")
diff --git a/iface/bind/bind.go b/client/iface/bind/bind.go
similarity index 100%
rename from iface/bind/bind.go
rename to client/iface/bind/bind.go
diff --git a/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go
similarity index 100%
rename from iface/bind/udp_mux.go
rename to client/iface/bind/udp_mux.go
diff --git a/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go
similarity index 100%
rename from iface/bind/udp_mux_universal.go
rename to client/iface/bind/udp_mux_universal.go
diff --git a/iface/bind/udp_muxed_conn.go b/client/iface/bind/udp_muxed_conn.go
similarity index 100%
rename from iface/bind/udp_muxed_conn.go
rename to client/iface/bind/udp_muxed_conn.go
diff --git a/client/iface/configurer/err.go b/client/iface/configurer/err.go
new file mode 100644
index 000000000..a64bba2dd
--- /dev/null
+++ b/client/iface/configurer/err.go
@@ -0,0 +1,5 @@
+package configurer
+
+import "errors"
+
+var ErrPeerNotFound = errors.New("peer not found")
diff --git a/iface/wg_configurer_kernel_unix.go b/client/iface/configurer/kernel_unix.go
similarity index 81%
rename from iface/wg_configurer_kernel_unix.go
rename to client/iface/configurer/kernel_unix.go
index 48ea70b7b..7c1c41669 100644
--- a/iface/wg_configurer_kernel_unix.go
+++ b/client/iface/configurer/kernel_unix.go
@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
-package iface
+package configurer
import (
"fmt"
@@ -12,18 +12,17 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
-type wgKernelConfigurer struct {
+type KernelConfigurer struct {
deviceName string
}
-func newWGConfigurer(deviceName string) wgConfigurer {
- wgc := &wgKernelConfigurer{
+func NewKernelConfigurer(deviceName string) *KernelConfigurer {
+ return &KernelConfigurer{
deviceName: deviceName,
}
- return wgc
}
-func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error {
+func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
@@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
return nil
}
-func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@@ -56,8 +55,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return err
}
peer := wgtypes.PeerConfig{
- PublicKey: peerKeyParsed,
- ReplaceAllowedIPs: true,
+ PublicKey: peerKeyParsed,
+ ReplaceAllowedIPs: false,
+ // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint,
@@ -74,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return nil
}
-func (c *wgKernelConfigurer) removePeer(peerKey string) error {
+func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -95,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error {
return nil
}
-func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
+func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
@@ -122,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro
return nil
}
-func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error {
+func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return fmt.Errorf("parse allowed IP: %w", err)
@@ -164,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e
return nil
}
-func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
+func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New()
if err != nil {
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
@@ -188,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer
return wgtypes.Peer{}, ErrPeerNotFound
}
-func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
+func (c *KernelConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New()
if err != nil {
return err
@@ -204,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
return wg.ConfigureDevice(c.deviceName, config)
}
-func (c *wgKernelConfigurer) close() {
+func (c *KernelConfigurer) Close() {
}
-func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) {
+func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
diff --git a/iface/name.go b/client/iface/configurer/name.go
similarity index 87%
rename from iface/name.go
rename to client/iface/configurer/name.go
index 706cb65ad..e2133d0ea 100644
--- a/iface/name.go
+++ b/client/iface/configurer/name.go
@@ -1,6 +1,6 @@
//go:build linux || windows || freebsd
-package iface
+package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0"
diff --git a/iface/name_darwin.go b/client/iface/configurer/name_darwin.go
similarity index 86%
rename from iface/name_darwin.go
rename to client/iface/configurer/name_darwin.go
index a4016ce15..034ce388d 100644
--- a/iface/name_darwin.go
+++ b/client/iface/configurer/name_darwin.go
@@ -1,6 +1,6 @@
//go:build darwin
-package iface
+package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100"
diff --git a/iface/uapi.go b/client/iface/configurer/uapi.go
similarity index 96%
rename from iface/uapi.go
rename to client/iface/configurer/uapi.go
index d7ff52e7b..4801841de 100644
--- a/iface/uapi.go
+++ b/client/iface/configurer/uapi.go
@@ -1,6 +1,6 @@
//go:build !windows
-package iface
+package configurer
import (
"net"
diff --git a/iface/uapi_windows.go b/client/iface/configurer/uapi_windows.go
similarity index 88%
rename from iface/uapi_windows.go
rename to client/iface/configurer/uapi_windows.go
index e1f466364..46fa90c2e 100644
--- a/iface/uapi_windows.go
+++ b/client/iface/configurer/uapi_windows.go
@@ -1,4 +1,4 @@
-package iface
+package configurer
import (
"net"
diff --git a/iface/wg_configurer_usp.go b/client/iface/configurer/usp.go
similarity index 91%
rename from iface/wg_configurer_usp.go
rename to client/iface/configurer/usp.go
index 04a29a60b..21d65ab2a 100644
--- a/iface/wg_configurer_usp.go
+++ b/client/iface/configurer/usp.go
@@ -1,4 +1,4 @@
-package iface
+package configurer
import (
"encoding/hex"
@@ -19,15 +19,15 @@ import (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
-type wgUSPConfigurer struct {
+type WGUSPConfigurer struct {
device *device.Device
deviceName string
uapiListener net.Listener
}
-func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
- wgCfg := &wgUSPConfigurer{
+func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
+ wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
}
@@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
return wgCfg
}
-func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error {
+func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey)
if err != nil {
@@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil {
@@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return err
}
peer := wgtypes.PeerConfig{
- PublicKey: peerKeyParsed,
- ReplaceAllowedIPs: true,
+ PublicKey: peerKeyParsed,
+ ReplaceAllowedIPs: false,
+ // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: []net.IPNet{*ipNet},
PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey,
@@ -79,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) removePeer(peerKey string) error {
+func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return err
@@ -96,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
+func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil {
return err
@@ -120,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
return c.device.IpcSet(toWgUserspaceString(config))
}
-func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
+func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet()
if err != nil {
return err
@@ -184,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
-func (t *wgUSPConfigurer) startUAPI() {
+func (t *WGUSPConfigurer) startUAPI() {
var err error
t.uapiListener, err = openUAPI(t.deviceName)
if err != nil {
@@ -206,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() {
}(t.uapiListener)
}
-func (t *wgUSPConfigurer) close() {
+func (t *WGUSPConfigurer) Close() {
if t.uapiListener != nil {
err := t.uapiListener.Close()
if err != nil {
@@ -222,7 +223,7 @@ func (t *wgUSPConfigurer) close() {
}
}
-func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) {
+func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err)
diff --git a/iface/wg_configurer_usp_test.go b/client/iface/configurer/usp_test.go
similarity index 99%
rename from iface/wg_configurer_usp_test.go
rename to client/iface/configurer/usp_test.go
index ac0fc6130..775339f24 100644
--- a/iface/wg_configurer_usp_test.go
+++ b/client/iface/configurer/usp_test.go
@@ -1,4 +1,4 @@
-package iface
+package configurer
import (
"encoding/hex"
diff --git a/client/iface/configurer/wgstats.go b/client/iface/configurer/wgstats.go
new file mode 100644
index 000000000..56d0d7310
--- /dev/null
+++ b/client/iface/configurer/wgstats.go
@@ -0,0 +1,9 @@
+package configurer
+
+import "time"
+
+type WGStats struct {
+ LastHandshake time.Time
+ TxBytes int64
+ RxBytes int64
+}
diff --git a/client/iface/device.go b/client/iface/device.go
new file mode 100644
index 000000000..0d4e69145
--- /dev/null
+++ b/client/iface/device.go
@@ -0,0 +1,18 @@
+//go:build !android
+
+package iface
+
+import (
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type WGTunDevice interface {
+ Create() (device.WGConfigurer, error)
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(address WGAddress) error
+ WgAddress() WGAddress
+ DeviceName() string
+ Close() error
+ FilteredDevice() *device.FilteredDevice
+}
diff --git a/iface/tun_adapter.go b/client/iface/device/adapter.go
similarity index 94%
rename from iface/tun_adapter.go
rename to client/iface/device/adapter.go
index adec93ed1..6ebc05390 100644
--- a/iface/tun_adapter.go
+++ b/client/iface/device/adapter.go
@@ -1,4 +1,4 @@
-package iface
+package device
// TunAdapter is an interface for create tun device from external service
type TunAdapter interface {
diff --git a/iface/address.go b/client/iface/device/address.go
similarity index 69%
rename from iface/address.go
rename to client/iface/device/address.go
index 5ff4fbc06..15de301da 100644
--- a/iface/address.go
+++ b/client/iface/device/address.go
@@ -1,18 +1,18 @@
-package iface
+package device
import (
"fmt"
"net"
)
-// WGAddress Wireguard parsed address
+// WGAddress WireGuard parsed address
type WGAddress struct {
IP net.IP
Network *net.IPNet
}
-// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address
-func parseWGAddress(address string) (WGAddress, error) {
+// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
+func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address)
if err != nil {
return WGAddress{}, err
diff --git a/iface/tun_args.go b/client/iface/device/args.go
similarity index 88%
rename from iface/tun_args.go
rename to client/iface/device/args.go
index 0eac2c4c0..d7b86b335 100644
--- a/iface/tun_args.go
+++ b/client/iface/device/args.go
@@ -1,4 +1,4 @@
-package iface
+package device
type MobileIFaceArguments struct {
TunAdapter TunAdapter // only for Android
diff --git a/iface/tun_android.go b/client/iface/device/device_android.go
similarity index 61%
rename from iface/tun_android.go
rename to client/iface/device/device_android.go
index 504993094..29e3f409d 100644
--- a/iface/tun_android.go
+++ b/client/iface/device/device_android.go
@@ -1,7 +1,6 @@
//go:build android
-// +build android
-package iface
+package device
import (
"strings"
@@ -12,11 +11,12 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform
-type wgTunDevice struct {
+// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
+type WGTunDevice struct {
address WGAddress
port int
key string
@@ -24,15 +24,15 @@ type wgTunDevice struct {
iceBind *bind.ICEBind
tunAdapter TunAdapter
- name string
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ name string
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice {
- return wgTunDevice{
+func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice {
+ return &WGTunDevice{
address: address,
port: port,
key: key,
@@ -42,7 +42,7 @@ func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet
}
}
-func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) {
+func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
log.Info("create tun interface")
routesString := routesToString(routes)
@@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err
}
t.name = name
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name)
- t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
+ t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
-func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *wgTunDevice) UpdateAddr(addr WGAddress) error {
+func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
-func (t *wgTunDevice) Close() error {
+func (t *WGTunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error {
return nil
}
-func (t *wgTunDevice) Device() *device.Device {
+func (t *WGTunDevice) Device() *device.Device {
return t.device
}
-func (t *wgTunDevice) DeviceName() string {
+func (t *WGTunDevice) DeviceName() string {
return t.name
}
-func (t *wgTunDevice) WgAddress() WGAddress {
+func (t *WGTunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *wgTunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
func routesToString(routes []string) string {
diff --git a/iface/tun_darwin.go b/client/iface/device/device_darwin.go
similarity index 64%
rename from iface/tun_darwin.go
rename to client/iface/device/device_darwin.go
index 364e5dfad..03e85a7f1 100644
--- a/iface/tun_darwin.go
+++ b/client/iface/device/device_darwin.go
@@ -1,8 +1,9 @@
//go:build !ios
-package iface
+package device
import (
+ "fmt"
"os/exec"
"github.com/pion/transport/v3"
@@ -10,10 +11,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunDevice struct {
+type TunDevice struct {
name string
address WGAddress
port int
@@ -21,14 +23,14 @@ type tunDevice struct {
mtu int
iceBind *bind.ICEBind
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
- return &tunDevice{
+func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
+ return &TunDevice{
name: name,
address: address,
port: port,
@@ -38,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
}
-func (t *tunDevice) Create() (wgConfigurer, error) {
+func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@@ -55,20 +57,20 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
- return nil, err
+ return nil, fmt.Errorf("error assigning ip: %s", err)
}
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
- return nil, err
+ t.configurer.Close()
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
-func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -83,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunDevice) UpdateAddr(address WGAddress) error {
+func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunDevice) Close() error {
+func (t *TunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -104,20 +106,20 @@ func (t *tunDevice) Close() error {
return nil
}
-func (t *tunDevice) WgAddress() WGAddress {
+func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunDevice) DeviceName() string {
+func (t *TunDevice) DeviceName() string {
return t.name
}
-func (t *tunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
-func (t *tunDevice) assignAddr() error {
+func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)
diff --git a/iface/device_wrapper.go b/client/iface/device/device_filter.go
similarity index 81%
rename from iface/device_wrapper.go
rename to client/iface/device/device_filter.go
index 2fa219395..f87f10429 100644
--- a/iface/device_wrapper.go
+++ b/client/iface/device/device_filter.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"net"
@@ -28,22 +28,23 @@ type PacketFilter interface {
SetNetwork(*net.IPNet)
}
-// DeviceWrapper to override Read or Write of packets
-type DeviceWrapper struct {
+// FilteredDevice to override Read or Write of packets
+type FilteredDevice struct {
tun.Device
+
filter PacketFilter
mutex sync.RWMutex
}
-// newDeviceWrapper constructor function
-func newDeviceWrapper(device tun.Device) *DeviceWrapper {
- return &DeviceWrapper{
+// newDeviceFilter constructor function
+func newDeviceFilter(device tun.Device) *FilteredDevice {
+ return &FilteredDevice{
Device: device,
}
}
// Read wraps read method with filtering feature
-func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
+func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
@@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err
}
// Write wraps write method with filtering feature
-func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
+func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
@@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
}
// SetFilter sets packet filter to device
-func (d *DeviceWrapper) SetFilter(filter PacketFilter) {
+func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()
diff --git a/iface/device_wrapper_test.go b/client/iface/device/device_filter_test.go
similarity index 95%
rename from iface/device_wrapper_test.go
rename to client/iface/device/device_filter_test.go
index 2d3725ea4..d3278b918 100644
--- a/iface/device_wrapper_test.go
+++ b/client/iface/device/device_filter_test.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"net"
@@ -7,7 +7,8 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
- mocks "github.com/netbirdio/netbird/iface/mocks"
+
+ mocks "github.com/netbirdio/netbird/client/iface/mocks"
)
func TestDeviceWrapperRead(t *testing.T) {
@@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil
})
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
bufs := [][]byte{{}}
sizes := []int{0}
@@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
bufs := [][]byte{buffer.Bytes()}
@@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{buffer.Bytes()}
@@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
- wrapped := newDeviceWrapper(tun)
+ wrapped := newDeviceFilter(tun)
wrapped.filter = filter
bufs := [][]byte{{}}
diff --git a/iface/tun_ios.go b/client/iface/device/device_ios.go
similarity index 63%
rename from iface/tun_ios.go
rename to client/iface/device/device_ios.go
index 6d53cc333..226e8a2e0 100644
--- a/iface/tun_ios.go
+++ b/client/iface/device/device_ios.go
@@ -1,7 +1,7 @@
//go:build ios
// +build ios
-package iface
+package device
import (
"os"
@@ -12,10 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunDevice struct {
+type TunDevice struct {
name string
address WGAddress
port int
@@ -23,14 +24,14 @@ type tunDevice struct {
iceBind *bind.ICEBind
tunFd int
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice {
- return &tunDevice{
+func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice {
+ return &TunDevice{
name: name,
address: address,
port: port,
@@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor
}
}
-func (t *tunDevice) Create() (wgConfigurer, error) {
+func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd)
@@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, err
}
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface")
- t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
+ t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
+ t.configurer.Close()
return nil, err
}
return t.configurer, nil
}
-func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunDevice) Device() *device.Device {
+func (t *TunDevice) Device() *device.Device {
return t.device
}
-func (t *tunDevice) DeviceName() string {
+func (t *TunDevice) DeviceName() string {
return t.name
}
-func (t *tunDevice) Close() error {
+func (t *TunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -119,15 +120,15 @@ func (t *tunDevice) Close() error {
return nil
}
-func (t *tunDevice) WgAddress() WGAddress {
+func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunDevice) UpdateAddr(addr WGAddress) error {
+func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement
return nil
}
-func (t *tunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
diff --git a/iface/tun_kernel_unix.go b/client/iface/device/device_kernel_unix.go
similarity index 73%
rename from iface/tun_kernel_unix.go
rename to client/iface/device/device_kernel_unix.go
index 019dd786b..f355d2cf7 100644
--- a/iface/tun_kernel_unix.go
+++ b/client/iface/device/device_kernel_unix.go
@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
-package iface
+package device
import (
"context"
@@ -10,11 +10,12 @@ import (
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/sharedsock"
)
-type tunKernelDevice struct {
+type TunKernelDevice struct {
name string
address WGAddress
wgPort int
@@ -31,11 +32,11 @@ type tunKernelDevice struct {
filterFn bind.FilterFn
}
-func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
+func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background())
- return &tunKernelDevice{
+ return &TunKernelDevice{
ctx: ctx,
ctxCancel: cancel,
name: name,
@@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in
}
}
-func (t *tunKernelDevice) Create() (wgConfigurer, error) {
+func (t *TunKernelDevice) Create() (WGConfigurer, error) {
link := newWGLink(t.name)
if err := link.recreate(); err != nil {
@@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("set mtu: %w", err)
}
- configurer := newWGConfigurer(t.name)
+ configurer := configurer.NewKernelConfigurer(t.name)
- if err := configurer.configureInterface(t.key, t.wgPort); err != nil {
- return nil, err
+ if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil {
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return configurer, nil
}
-func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil {
return t.udpMux, nil
}
@@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil
}
-func (t *tunKernelDevice) UpdateAddr(address WGAddress) error {
+func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunKernelDevice) Close() error {
+func (t *TunKernelDevice) Close() error {
if t.link == nil {
return nil
}
@@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error {
return closErr
}
-func (t *tunKernelDevice) WgAddress() WGAddress {
+func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunKernelDevice) DeviceName() string {
+func (t *TunKernelDevice) DeviceName() string {
return t.name
}
-func (t *tunKernelDevice) Wrapper() *DeviceWrapper {
+func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil
}
// assignAddr Adds IP address to the tunnel interface
-func (t *tunKernelDevice) assignAddr() error {
+func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address)
}
diff --git a/iface/tun_netstack.go b/client/iface/device/device_netstack.go
similarity index 51%
rename from iface/tun_netstack.go
rename to client/iface/device/device_netstack.go
index df2f75c45..440a1ca19 100644
--- a/iface/tun_netstack.go
+++ b/client/iface/device/device_netstack.go
@@ -1,7 +1,7 @@
//go:build !android
// +build !android
-package iface
+package device
import (
"fmt"
@@ -10,11 +10,12 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
-type tunNetstackDevice struct {
+type TunNetstackDevice struct {
name string
address WGAddress
port int
@@ -23,15 +24,15 @@ type tunNetstackDevice struct {
listenAddress string
iceBind *bind.ICEBind
- device *device.Device
- wrapper *DeviceWrapper
- nsTun *netstack.NetStackTun
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ nsTun *netstack.NetStackTun
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice {
- return &tunNetstackDevice{
+func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice {
+ return &TunNetstackDevice{
name: name,
address: address,
port: wgPort,
@@ -42,33 +43,33 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string
}
}
-func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
+func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create()
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
- t.wrapper = newDeviceWrapper(tunIface)
+ t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
_ = tunIface.Close()
- return nil, err
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
log.Debugf("device has been created: %s", t.name)
return t.configurer, nil
}
-func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunNetstackDevice) UpdateAddr(WGAddress) error {
+func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil
}
-func (t *tunNetstackDevice) Close() error {
+func (t *TunNetstackDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error {
return nil
}
-func (t *tunNetstackDevice) WgAddress() WGAddress {
+func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunNetstackDevice) DeviceName() string {
+func (t *TunNetstackDevice) DeviceName() string {
return t.name
}
-func (t *tunNetstackDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
diff --git a/iface/tun_usp_unix.go b/client/iface/device/device_usp_unix.go
similarity index 55%
rename from iface/tun_usp_unix.go
rename to client/iface/device/device_usp_unix.go
index 814c9ca89..4175f6556 100644
--- a/iface/tun_usp_unix.go
+++ b/client/iface/device/device_usp_unix.go
@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd
-package iface
+package device
import (
"fmt"
@@ -12,10 +12,11 @@ import (
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunUSPDevice struct {
+type USPDevice struct {
name string
address WGAddress
port int
@@ -23,39 +24,38 @@ type tunUSPDevice struct {
mtu int
iceBind *bind.ICEBind
- device *device.Device
- wrapper *DeviceWrapper
- udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ device *device.Device
+ filteredDevice *FilteredDevice
+ udpMux *bind.UniversalUDPMuxDefault
+ configurer WGConfigurer
}
-func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
+func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
- return &tunUSPDevice{
+ return &USPDevice{
name: name,
address: address,
port: port,
key: key,
mtu: mtu,
- iceBind: bind.NewICEBind(transportNet, filterFn),
- }
+ iceBind: bind.NewICEBind(transportNet, filterFn)}
}
-func (t *tunUSPDevice) Create() (wgConfigurer, error) {
+func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil {
- log.Debugf("failed to create tun unterface (%s, %d): %s", t.name, t.mtu, err)
- return nil, err
+ log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
+ return nil, fmt.Errorf("error creating tun device: %s", err)
}
- t.wrapper = newDeviceWrapper(tunIface)
+ t.filteredDevice = newDeviceFilter(tunIface)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@@ -63,20 +63,20 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
- return nil, err
+ return nil, fmt.Errorf("error assigning ip: %s", err)
}
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
- return nil, err
+ t.configurer.Close()
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
-func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil {
return nil, fmt.Errorf("device is not ready yet")
}
@@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunUSPDevice) UpdateAddr(address WGAddress) error {
+func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunUSPDevice) Close() error {
+func (t *USPDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error {
return nil
}
-func (t *tunUSPDevice) WgAddress() WGAddress {
+func (t *USPDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunUSPDevice) DeviceName() string {
+func (t *USPDevice) DeviceName() string {
return t.name
}
-func (t *tunUSPDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *USPDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
// assignAddr Adds IP address to the tunnel interface
-func (t *tunUSPDevice) assignAddr() error {
+func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name)
return link.assignAddr(t.address)
diff --git a/iface/tun_windows.go b/client/iface/device/device_windows.go
similarity index 60%
rename from iface/tun_windows.go
rename to client/iface/device/device_windows.go
index 0d658059f..f3e216ccd 100644
--- a/iface/tun_windows.go
+++ b/client/iface/device/device_windows.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"fmt"
@@ -11,10 +11,13 @@ import (
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
)
-type tunDevice struct {
+const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
+
+type TunDevice struct {
name string
address WGAddress
port int
@@ -24,13 +27,13 @@ type tunDevice struct {
device *device.Device
nativeTunDevice *tun.NativeTun
- wrapper *DeviceWrapper
+ filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault
- configurer wgConfigurer
+ configurer WGConfigurer
}
-func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
- return &tunDevice{
+func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice {
+ return &TunDevice{
name: name,
address: address,
port: port,
@@ -40,18 +43,31 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int,
}
}
-func (t *tunDevice) Create() (wgConfigurer, error) {
- log.Info("create tun interface")
- tunDevice, err := tun.CreateTUN(t.name, t.mtu)
+func getGUID() (windows.GUID, error) {
+ guidString := defaultWindowsGUIDSTring
+ if CustomWindowsGUIDString != "" {
+ guidString = CustomWindowsGUIDString
+ }
+ return windows.GUIDFromString(guidString)
+}
+
+func (t *TunDevice) Create() (WGConfigurer, error) {
+ guid, err := getGUID()
if err != nil {
+ log.Errorf("failed to get GUID: %s", err)
return nil, err
}
+ log.Info("create tun interface")
+ tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu)
+ if err != nil {
+ return nil, fmt.Errorf("error creating tun device: %s", err)
+ }
t.nativeTunDevice = tunDevice.(*tun.NativeTun)
- t.wrapper = newDeviceWrapper(tunDevice)
+ t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice(
- t.wrapper,
+ t.filteredDevice,
t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "),
)
@@ -74,20 +90,20 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
err = t.assignAddr()
if err != nil {
t.device.Close()
- return nil, err
+ return nil, fmt.Errorf("error assigning ip: %s", err)
}
- t.configurer = newWGUSPConfigurer(t.device, t.name)
- err = t.configurer.configureInterface(t.key, t.port)
+ t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
+ err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil {
t.device.Close()
- t.configurer.close()
- return nil, err
+ t.configurer.Close()
+ return nil, fmt.Errorf("error configuring interface: %s", err)
}
return t.configurer, nil
}
-func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
+func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up()
if err != nil {
return nil, err
@@ -102,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil
}
-func (t *tunDevice) UpdateAddr(address WGAddress) error {
+func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address
return t.assignAddr()
}
-func (t *tunDevice) Close() error {
+func (t *TunDevice) Close() error {
if t.configurer != nil {
- t.configurer.close()
+ t.configurer.Close()
}
if t.device != nil {
@@ -123,19 +139,19 @@ func (t *tunDevice) Close() error {
}
return nil
}
-func (t *tunDevice) WgAddress() WGAddress {
+func (t *TunDevice) WgAddress() WGAddress {
return t.address
}
-func (t *tunDevice) DeviceName() string {
+func (t *TunDevice) DeviceName() string {
return t.name
}
-func (t *tunDevice) Wrapper() *DeviceWrapper {
- return t.wrapper
+func (t *TunDevice) FilteredDevice() *FilteredDevice {
+ return t.filteredDevice
}
-func (t *tunDevice) getInterfaceGUIDString() (string, error) {
+func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet")
}
@@ -149,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) {
}
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided
-func (t *tunDevice) assignAddr() error {
+func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})
diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go
new file mode 100644
index 000000000..0196b0085
--- /dev/null
+++ b/client/iface/device/interface.go
@@ -0,0 +1,20 @@
+package device
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/configurer"
+)
+
+type WGConfigurer interface {
+ ConfigureInterface(privateKey string, port int) error
+ UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+ Close()
+ GetStats(peerKey string) (configurer.WGStats, error)
+}
diff --git a/iface/module.go b/client/iface/device/kernel_module.go
similarity index 92%
rename from iface/module.go
rename to client/iface/device/kernel_module.go
index ca70cf3c7..1bdd6f7c6 100644
--- a/iface/module.go
+++ b/client/iface/device/kernel_module.go
@@ -1,6 +1,6 @@
//go:build (!linux && !freebsd) || android
-package iface
+package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool {
diff --git a/iface/module_freebsd.go b/client/iface/device/kernel_module_freebsd.go
similarity index 84%
rename from iface/module_freebsd.go
rename to client/iface/device/kernel_module_freebsd.go
index 00ad882c2..dd6c8b408 100644
--- a/iface/module_freebsd.go
+++ b/client/iface/device/kernel_module_freebsd.go
@@ -1,4 +1,4 @@
-package iface
+package device
// WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool {
@@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool {
return false
}
-// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
-func tunModuleIsLoaded() bool {
+// ModuleTunIsLoaded checks if the tun module exists; if it does not, attempts to load it
+func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it
return true
diff --git a/iface/module_linux.go b/client/iface/device/kernel_module_linux.go
similarity index 98%
rename from iface/module_linux.go
rename to client/iface/device/kernel_module_linux.go
index 11c0482d5..0d195779d 100644
--- a/iface/module_linux.go
+++ b/client/iface/device/kernel_module_linux.go
@@ -1,7 +1,7 @@
//go:build linux && !android
// Package iface provides wireguard network interface creation and management
-package iface
+package device
import (
"bufio"
@@ -66,8 +66,8 @@ func getModuleRoot() string {
return filepath.Join(moduleLibDir, string(uname.Release[:i]))
}
-// tunModuleIsLoaded check if tun module exist, if is not attempt to load it
-func tunModuleIsLoaded() bool {
+// ModuleTunIsLoaded checks if the tun module exists; if it does not, attempts to load it
+func ModuleTunIsLoaded() bool {
_, err := os.Stat("/dev/net/tun")
if err == nil {
return true
diff --git a/iface/module_linux_test.go b/client/iface/device/kernel_module_linux_test.go
similarity index 98%
rename from iface/module_linux_test.go
rename to client/iface/device/kernel_module_linux_test.go
index 97e9b1f78..de9656e47 100644
--- a/iface/module_linux_test.go
+++ b/client/iface/device/kernel_module_linux_test.go
@@ -1,4 +1,6 @@
-package iface
+//go:build linux && !android
+
+package device
import (
"bufio"
@@ -132,7 +134,7 @@ func resetGlobals() {
}
func createFiles(t *testing.T) (string, []module) {
- t.Helper()
+ t.Helper()
writeFile := func(path, text string) {
if err := os.WriteFile(path, []byte(text), 0644); err != nil {
t.Fatal(err)
@@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) {
}
func getRandomLoadedModule(t *testing.T) (string, error) {
- t.Helper()
+ t.Helper()
f, err := os.Open("/proc/modules")
if err != nil {
return "", err
diff --git a/iface/tun_link_freebsd.go b/client/iface/device/wg_link_freebsd.go
similarity index 95%
rename from iface/tun_link_freebsd.go
rename to client/iface/device/wg_link_freebsd.go
index be7921fdb..104010f47 100644
--- a/iface/tun_link_freebsd.go
+++ b/client/iface/device/wg_link_freebsd.go
@@ -1,10 +1,11 @@
-package iface
+package device
import (
"fmt"
- "github.com/netbirdio/netbird/iface/freebsd"
log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/iface/freebsd"
)
type wgLink struct {
diff --git a/iface/tun_link_linux.go b/client/iface/device/wg_link_linux.go
similarity index 99%
rename from iface/tun_link_linux.go
rename to client/iface/device/wg_link_linux.go
index 3ce644e84..a15cffe48 100644
--- a/iface/tun_link_linux.go
+++ b/client/iface/device/wg_link_linux.go
@@ -1,6 +1,6 @@
//go:build linux && !android
-package iface
+package device
import (
"fmt"
diff --git a/iface/wg_log.go b/client/iface/device/wg_log.go
similarity index 93%
rename from iface/wg_log.go
rename to client/iface/device/wg_log.go
index b44f6fc0b..db2f3111f 100644
--- a/iface/wg_log.go
+++ b/client/iface/device/wg_log.go
@@ -1,4 +1,4 @@
-package iface
+package device
import (
"os"
diff --git a/client/iface/device/windows_guid.go b/client/iface/device/windows_guid.go
new file mode 100644
index 000000000..1c7d40d13
--- /dev/null
+++ b/client/iface/device/windows_guid.go
@@ -0,0 +1,4 @@
+package device
+
+// CustomWindowsGUIDString is a custom GUID string for the interface
+var CustomWindowsGUIDString string
diff --git a/client/iface/device_android.go b/client/iface/device_android.go
new file mode 100644
index 000000000..3d15080ff
--- /dev/null
+++ b/client/iface/device_android.go
@@ -0,0 +1,16 @@
+package iface
+
+import (
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type WGTunDevice interface {
+ Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(address WGAddress) error
+ WgAddress() WGAddress
+ DeviceName() string
+ Close() error
+ FilteredDevice() *device.FilteredDevice
+}
diff --git a/iface/freebsd/errors.go b/client/iface/freebsd/errors.go
similarity index 100%
rename from iface/freebsd/errors.go
rename to client/iface/freebsd/errors.go
diff --git a/iface/freebsd/iface.go b/client/iface/freebsd/iface.go
similarity index 100%
rename from iface/freebsd/iface.go
rename to client/iface/freebsd/iface.go
diff --git a/iface/freebsd/iface_internal_test.go b/client/iface/freebsd/iface_internal_test.go
similarity index 100%
rename from iface/freebsd/iface_internal_test.go
rename to client/iface/freebsd/iface_internal_test.go
diff --git a/iface/freebsd/link.go b/client/iface/freebsd/link.go
similarity index 100%
rename from iface/freebsd/link.go
rename to client/iface/freebsd/link.go
diff --git a/iface/iface.go b/client/iface/iface.go
similarity index 59%
rename from iface/iface.go
rename to client/iface/iface.go
index 928077a3d..accf5ce0a 100644
--- a/iface/iface.go
+++ b/client/iface/iface.go
@@ -9,28 +9,27 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
const (
- DefaultMTU = 1280
- DefaultWgPort = 51820
+ DefaultMTU = 1280
+ DefaultWgPort = 51820
+ WgInterfaceDefault = configurer.WgInterfaceDefault
)
-// WGIface represents a interface instance
+type WGAddress = device.WGAddress
+
+// WGIface represents an interface instance
type WGIface struct {
- tun wgTunDevice
+ tun WGTunDevice
userspaceBind bool
mu sync.Mutex
- configurer wgConfigurer
- filter PacketFilter
-}
-
-type WGStats struct {
- LastHandshake time.Time
- TxBytes int64
- RxBytes int64
+ configurer device.WGConfigurer
+ filter device.PacketFilter
}
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@@ -44,7 +43,7 @@ func (w *WGIface) Name() string {
}
// Address returns the interface address
-func (w *WGIface) Address() WGAddress {
+func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress()
}
@@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock()
defer w.mu.Unlock()
- addr, err := parseWGAddress(newAddr)
+ addr, err := device.ParseWGAddress(newAddr)
if err != nil {
return err
}
@@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
- return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
+ return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
}
// RemovePeer removes a Wireguard Peer from the interface iface
@@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
- return w.configurer.removePeer(peerKey)
+ return w.configurer.RemovePeer(peerKey)
}
// AddAllowedIP adds a prefix to the allowed IPs list of peer
@@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock()
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
- return w.configurer.addAllowedIP(peerKey, allowedIP)
+ return w.configurer.AddAllowedIP(peerKey, allowedIP)
}
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
@@ -117,34 +116,50 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock()
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
- return w.configurer.removeAllowedIP(peerKey, allowedIP)
+ return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
}
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
- return w.tun.Close()
+
+ err := w.tun.Close()
+ if err != nil {
+ return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)
+ }
+
+ err = w.waitUntilRemoved()
+ if err != nil {
+ log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
+ err = w.Destroy()
+ if err != nil {
+ return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err)
+ }
+ log.Infof("interface %s successfully removed", w.Name())
+ }
+
+ return nil
}
// SetFilter sets packet filters for the userspace implementation
-func (w *WGIface) SetFilter(filter PacketFilter) error {
+func (w *WGIface) SetFilter(filter device.PacketFilter) error {
w.mu.Lock()
defer w.mu.Unlock()
- if w.tun.Wrapper() == nil {
+ if w.tun.FilteredDevice() == nil {
return fmt.Errorf("userspace packet filtering not handled on this device")
}
w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
- w.tun.Wrapper().SetFilter(filter)
+ w.tun.FilteredDevice().SetFilter(filter)
return nil
}
// GetFilter returns packet filter used by interface if it uses userspace device implementation
-func (w *WGIface) GetFilter() PacketFilter {
+func (w *WGIface) GetFilter() device.PacketFilter {
w.mu.Lock()
defer w.mu.Unlock()
@@ -152,14 +167,41 @@ func (w *WGIface) GetFilter() PacketFilter {
}
// GetDevice to interact with raw device (with filtering)
-func (w *WGIface) GetDevice() *DeviceWrapper {
+func (w *WGIface) GetDevice() *device.FilteredDevice {
w.mu.Lock()
defer w.mu.Unlock()
- return w.tun.Wrapper()
+ return w.tun.FilteredDevice()
}
// GetStats returns the last handshake time, rx and tx bytes for the given peer
-func (w *WGIface) GetStats(peerKey string) (WGStats, error) {
- return w.configurer.getStats(peerKey)
+func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
+ return w.configurer.GetStats(peerKey)
+}
+
+func (w *WGIface) waitUntilRemoved() error {
+ maxWaitTime := 5 * time.Second
+ timeout := time.NewTimer(maxWaitTime)
+ defer timeout.Stop()
+
+ for {
+ iface, err := net.InterfaceByName(w.Name())
+ if err != nil {
+ if _, ok := err.(*net.OpError); ok {
+ log.Infof("interface %s has been removed", w.Name())
+ return nil
+ }
+ log.Debugf("failed to get interface by name %s: %v", w.Name(), err)
+ } else if iface == nil {
+ log.Infof("interface %s has been removed", w.Name())
+ return nil
+ }
+
+ select {
+ case <-timeout.C:
+ return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name())
+ default:
+ time.Sleep(100 * time.Millisecond)
+ }
+ }
}
diff --git a/iface/iface_android.go b/client/iface/iface_android.go
similarity index 67%
rename from iface/iface_android.go
rename to client/iface/iface_android.go
index 99f6885a5..5ed476e70 100644
--- a/iface/iface_android.go
+++ b/client/iface/iface_android.go
@@ -5,18 +5,19 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
- tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
+ tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
diff --git a/iface/iface_create.go b/client/iface/iface_create.go
similarity index 89%
rename from iface/iface_create.go
rename to client/iface/iface_create.go
index cfc555f2e..f389019ed 100644
--- a/iface/iface_create.go
+++ b/client/iface/iface_create.go
@@ -1,4 +1,4 @@
-//go:build !android
+//go:build (!android && !darwin) || ios
package iface
diff --git a/client/iface/iface_darwin.go b/client/iface/iface_darwin.go
new file mode 100644
index 000000000..b46ea0f80
--- /dev/null
+++ b/client/iface/iface_darwin.go
@@ -0,0 +1,67 @@
+//go:build !ios
+
+package iface
+
+import (
+ "fmt"
+ "time"
+
+ "github.com/cenkalti/backoff/v4"
+ "github.com/pion/transport/v3"
+
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/netstack"
+)
+
+// NewWGIFace Creates a new WireGuard interface instance
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
+ if err != nil {
+ return nil, err
+ }
+
+ wgIFace := &WGIface{
+ userspaceBind: true,
+ }
+
+ if netstack.IsEnabled() {
+ wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
+ return wgIFace, nil
+ }
+
+ wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
+
+ return wgIFace, nil
+}
+
+// CreateOnAndroid this function only makes sense on mobile platforms
+func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
+ return fmt.Errorf("this function has not implemented on this platform")
+}
+
+// Create creates a new Wireguard interface, sets a given IP and brings it up.
+// Will reuse an existing one.
+// this function is different on Android
+func (w *WGIface) Create() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ backOff := &backoff.ExponentialBackOff{
+ InitialInterval: 20 * time.Millisecond,
+ MaxElapsedTime: 500 * time.Millisecond,
+ Stop: backoff.Stop,
+ Clock: backoff.SystemClock,
+ }
+
+ operation := func() error {
+ cfgr, err := w.tun.Create()
+ if err != nil {
+ return err
+ }
+ w.configurer = cfgr
+ return nil
+ }
+
+ return backoff.Retry(operation, backOff)
+}
diff --git a/client/iface/iface_destroy_bsd.go b/client/iface/iface_destroy_bsd.go
new file mode 100644
index 000000000..c16010a1c
--- /dev/null
+++ b/client/iface/iface_destroy_bsd.go
@@ -0,0 +1,17 @@
+//go:build darwin || dragonfly || freebsd || netbsd || openbsd
+
+package iface
+
+import (
+ "fmt"
+ "os/exec"
+)
+
+func (w *WGIface) Destroy() error {
+ out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
+ }
+
+ return nil
+}
diff --git a/client/iface/iface_destroy_linux.go b/client/iface/iface_destroy_linux.go
new file mode 100644
index 000000000..e9d54bed1
--- /dev/null
+++ b/client/iface/iface_destroy_linux.go
@@ -0,0 +1,22 @@
+//go:build linux && !android
+
+package iface
+
+import (
+ "fmt"
+
+ "github.com/vishvananda/netlink"
+)
+
+func (w *WGIface) Destroy() error {
+ link, err := netlink.LinkByName(w.Name())
+ if err != nil {
+ return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err)
+ }
+
+ if err := netlink.LinkDel(link); err != nil {
+ return fmt.Errorf("failed to delete link %s: %w", w.Name(), err)
+ }
+
+ return nil
+}
diff --git a/client/iface/iface_destroy_mobile.go b/client/iface/iface_destroy_mobile.go
new file mode 100644
index 000000000..89f87a598
--- /dev/null
+++ b/client/iface/iface_destroy_mobile.go
@@ -0,0 +1,9 @@
+//go:build android || (ios && !darwin)
+
+package iface
+
+import "errors"
+
+func (w *WGIface) Destroy() error {
+ return errors.New("not supported on mobile")
+}
diff --git a/client/iface/iface_destroy_windows.go b/client/iface/iface_destroy_windows.go
new file mode 100644
index 000000000..0bfa4e211
--- /dev/null
+++ b/client/iface/iface_destroy_windows.go
@@ -0,0 +1,32 @@
+//go:build windows
+
+package iface
+
+import (
+ "fmt"
+ "os/exec"
+
+ log "github.com/sirupsen/logrus"
+)
+
+func (w *WGIface) Destroy() error {
+ netshCmd := GetSystem32Command("netsh")
+ out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput()
+ if err != nil {
+ return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out)
+ }
+ return nil
+}
+
+// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it
+// in the path it will return the full path of a command assuming C:\windows\system32 as the base path.
+func GetSystem32Command(command string) string {
+ _, err := exec.LookPath(command)
+ if err == nil {
+ return command
+ }
+
+ log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command)
+
+ return "C:\\windows\\system32\\" + command + ".exe"
+}
diff --git a/iface/iface_ios.go b/client/iface/iface_ios.go
similarity index 59%
rename from iface/iface_ios.go
rename to client/iface/iface_ios.go
index 6babe5964..fc0214748 100644
--- a/iface/iface_ios.go
+++ b/client/iface/iface_ios.go
@@ -7,17 +7,18 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
- tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
+ tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil
diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go
new file mode 100644
index 000000000..703da9ce0
--- /dev/null
+++ b/client/iface/iface_moc.go
@@ -0,0 +1,105 @@
+package iface
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type MockWGIface struct {
+ CreateFunc func() error
+ CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
+ IsUserspaceBindFunc func() bool
+ NameFunc func() string
+ AddressFunc func() device.WGAddress
+ ToInterfaceFunc func() *net.Interface
+ UpFunc func() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddrFunc func(newAddr string) error
+ UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeerFunc func(peerKey string) error
+ AddAllowedIPFunc func(peerKey string, allowedIP string) error
+ RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
+ CloseFunc func() error
+ SetFilterFunc func(filter device.PacketFilter) error
+ GetFilterFunc func() device.PacketFilter
+ GetDeviceFunc func() *device.FilteredDevice
+ GetStatsFunc func(peerKey string) (configurer.WGStats, error)
+ GetInterfaceGUIDStringFunc func() (string, error)
+}
+
+func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
+ return m.GetInterfaceGUIDStringFunc()
+}
+
+func (m *MockWGIface) Create() error {
+ return m.CreateFunc()
+}
+
+func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error {
+ return m.CreateOnAndroidFunc(routeRange, ip, domains)
+}
+
+func (m *MockWGIface) IsUserspaceBind() bool {
+ return m.IsUserspaceBindFunc()
+}
+
+func (m *MockWGIface) Name() string {
+ return m.NameFunc()
+}
+
+func (m *MockWGIface) Address() device.WGAddress {
+ return m.AddressFunc()
+}
+
+func (m *MockWGIface) ToInterface() *net.Interface {
+ return m.ToInterfaceFunc()
+}
+
+func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) {
+ return m.UpFunc()
+}
+
+func (m *MockWGIface) UpdateAddr(newAddr string) error {
+ return m.UpdateAddrFunc(newAddr)
+}
+
+func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
+ return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
+}
+
+func (m *MockWGIface) RemovePeer(peerKey string) error {
+ return m.RemovePeerFunc(peerKey)
+}
+
+func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error {
+ return m.AddAllowedIPFunc(peerKey, allowedIP)
+}
+
+func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
+ return m.RemoveAllowedIPFunc(peerKey, allowedIP)
+}
+
+func (m *MockWGIface) Close() error {
+ return m.CloseFunc()
+}
+
+func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
+ return m.SetFilterFunc(filter)
+}
+
+func (m *MockWGIface) GetFilter() device.PacketFilter {
+ return m.GetFilterFunc()
+}
+
+func (m *MockWGIface) GetDevice() *device.FilteredDevice {
+ return m.GetDeviceFunc()
+}
+
+func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
+ return m.GetStatsFunc(peerKey)
+}
diff --git a/iface/iface_test.go b/client/iface/iface_test.go
similarity index 84%
rename from iface/iface_test.go
rename to client/iface/iface_test.go
index 43c44b770..87a68addb 100644
--- a/iface/iface_test.go
+++ b/client/iface/iface_test.go
@@ -4,14 +4,18 @@ import (
"fmt"
"net"
"net/netip"
+ "strings"
"testing"
"time"
+ "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/device"
)
// keep darwin compatibility
@@ -174,6 +178,72 @@ func Test_Close(t *testing.T) {
}
}
+func TestRecreation(t *testing.T) {
+ for i := 0; i < 100; i++ {
+ t.Run(fmt.Sprintf("down-%d", i), func(t *testing.T) {
+ ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2)
+ wgIP := "10.99.99.2/32"
+ wgPort := 33100
+ newNet, err := stdnet.NewNet()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for {
+ _, err = net.InterfaceByName(ifaceName)
+ if err != nil {
+ t.Logf("interface %s not found: err: %s", ifaceName, err)
+ break
+ }
+ t.Logf("interface %s found", ifaceName)
+ }
+
+ err = iface.Create()
+ if err != nil {
+ t.Fatal(err)
+ }
+ wg, err := wgctrl.New()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ err = wg.Close()
+ if err != nil {
+ t.Error(err)
+ }
+ }()
+
+ _, err = iface.Up()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ for {
+ _, err = net.InterfaceByName(ifaceName)
+ if err == nil {
+ t.Logf("interface %s found", ifaceName)
+
+ break
+ }
+ t.Logf("interface %s not found: err: %s", ifaceName, err)
+
+ }
+
+ start := time.Now()
+ err = iface.Close()
+ t.Logf("down time: %s", time.Since(start))
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
+
func Test_ConfigureInterface(t *testing.T) {
ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3)
wgIP := "10.99.99.5/30"
@@ -345,6 +415,9 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
+ guid := fmt.Sprintf("{%s}", uuid.New().String())
+ device.CustomWindowsGUIDString = strings.ToLower(guid)
+
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil {
t.Fatal(err)
@@ -364,6 +437,9 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err)
}
+ guid = fmt.Sprintf("{%s}", uuid.New().String())
+ device.CustomWindowsGUIDString = strings.ToLower(guid)
+
newNet, err = stdnet.NewNet()
if err != nil {
t.Fatal(err)
diff --git a/iface/iface_unix.go b/client/iface/iface_unix.go
similarity index 53%
rename from iface/iface_unix.go
rename to client/iface/iface_unix.go
index 9608df1ad..09dbb2c1f 100644
--- a/iface/iface_unix.go
+++ b/client/iface/iface_unix.go
@@ -8,13 +8,14 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
- wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
+ wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
- if WireGuardModuleIsLoaded() {
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
+ if device.WireGuardModuleIsLoaded() {
+ wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
- if !tunModuleIsLoaded() {
+ if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
- wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
+ wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
diff --git a/iface/iface_windows.go b/client/iface/iface_windows.go
similarity index 52%
rename from iface/iface_windows.go
rename to client/iface/iface_windows.go
index c5edd27a9..6845ef3dd 100644
--- a/iface/iface_windows.go
+++ b/client/iface/iface_windows.go
@@ -5,13 +5,14 @@ import (
"github.com/pion/transport/v3"
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
+ "github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
+func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
+ wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
@@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
}
if netstack.IsEnabled() {
- wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
+ wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
+ wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
@@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
- return w.tun.(*tunDevice).getInterfaceGUIDString()
+ return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}
diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go
new file mode 100644
index 000000000..cb6d7ccd9
--- /dev/null
+++ b/client/iface/iwginterface.go
@@ -0,0 +1,34 @@
+//go:build !windows
+
+package iface
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type IWGIface interface {
+ Create() error
+ CreateOnAndroid(routeRange []string, ip string, domains []string) error
+ IsUserspaceBind() bool
+ Name() string
+ Address() device.WGAddress
+ ToInterface() *net.Interface
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(newAddr string) error
+ UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+ Close() error
+ SetFilter(filter device.PacketFilter) error
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
+}
diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go
new file mode 100644
index 000000000..6baeb66ae
--- /dev/null
+++ b/client/iface/iwginterface_windows.go
@@ -0,0 +1,33 @@
+package iface
+
+import (
+ "net"
+ "time"
+
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
+
+type IWGIface interface {
+ Create() error
+ CreateOnAndroid(routeRange []string, ip string, domains []string) error
+ IsUserspaceBind() bool
+ Name() string
+ Address() device.WGAddress
+ ToInterface() *net.Interface
+ Up() (*bind.UniversalUDPMuxDefault, error)
+ UpdateAddr(newAddr string) error
+ UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
+ RemovePeer(peerKey string) error
+ AddAllowedIP(peerKey string, allowedIP string) error
+ RemoveAllowedIP(peerKey string, allowedIP string) error
+ Close() error
+ SetFilter(filter device.PacketFilter) error
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
+ GetInterfaceGUIDString() (string, error)
+}
diff --git a/iface/mocks/README.md b/client/iface/mocks/README.md
similarity index 100%
rename from iface/mocks/README.md
rename to client/iface/mocks/README.md
diff --git a/iface/mocks/filter.go b/client/iface/mocks/filter.go
similarity index 97%
rename from iface/mocks/filter.go
rename to client/iface/mocks/filter.go
index 2d80d69f1..6348e0e77 100644
--- a/iface/mocks/filter.go
+++ b/client/iface/mocks/filter.go
@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
+// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
diff --git a/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go
similarity index 97%
rename from iface/mocks/iface/mocks/filter.go
rename to client/iface/mocks/iface/mocks/filter.go
index 059a2b9a0..17e123abb 100644
--- a/iface/mocks/iface/mocks/filter.go
+++ b/client/iface/mocks/iface/mocks/filter.go
@@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter)
+// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter)
// Package mocks is a generated GoMock package.
package mocks
diff --git a/iface/mocks/tun.go b/client/iface/mocks/tun.go
similarity index 100%
rename from iface/mocks/tun.go
rename to client/iface/mocks/tun.go
diff --git a/iface/netstack/dialer.go b/client/iface/netstack/dialer.go
similarity index 100%
rename from iface/netstack/dialer.go
rename to client/iface/netstack/dialer.go
diff --git a/iface/netstack/env.go b/client/iface/netstack/env.go
similarity index 100%
rename from iface/netstack/env.go
rename to client/iface/netstack/env.go
diff --git a/iface/netstack/proxy.go b/client/iface/netstack/proxy.go
similarity index 100%
rename from iface/netstack/proxy.go
rename to client/iface/netstack/proxy.go
diff --git a/iface/netstack/tun.go b/client/iface/netstack/tun.go
similarity index 100%
rename from iface/netstack/tun.go
rename to client/iface/netstack/tun.go
diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go
new file mode 100644
index 000000000..e27fce439
--- /dev/null
+++ b/client/internal/acl/id/id.go
@@ -0,0 +1,25 @@
+package id
+
+import (
+ "fmt"
+ "net/netip"
+
+ "github.com/netbirdio/netbird/client/firewall/manager"
+)
+
+type RuleID string
+
+func (r RuleID) GetRuleID() string {
+ return string(r)
+}
+
+func GenerateRouteRuleKey(
+ sources []netip.Prefix,
+ destination netip.Prefix,
+ proto manager.Protocol,
+ sPort *manager.Port,
+ dPort *manager.Port,
+ action manager.Action,
+) RuleID {
+ return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action))
+}
diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go
index fd2c2c875..ce2a12af1 100644
--- a/client/internal/acl/manager.go
+++ b/client/internal/acl/manager.go
@@ -5,6 +5,7 @@ import (
"encoding/hex"
"fmt"
"net"
+ "net/netip"
"strconv"
"sync"
"time"
@@ -12,6 +13,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -23,16 +25,18 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
- firewall firewall.Manager
- ipsetCounter int
- rulesPairs map[string][]firewall.Rule
- mutex sync.Mutex
+ firewall firewall.Manager
+ ipsetCounter int
+ peerRulesPairs map[id.RuleID][]firewall.Rule
+ routeRules map[id.RuleID]struct{}
+ mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
return &DefaultManager{
- firewall: fm,
- rulesPairs: make(map[string][]firewall.Rule),
+ firewall: fm,
+ peerRulesPairs: make(map[id.RuleID][]firewall.Rule),
+ routeRules: make(map[id.RuleID]struct{}),
}
}
@@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
start := time.Now()
defer func() {
total := 0
- for _, pairs := range d.rulesPairs {
+ for _, pairs := range d.peerRulesPairs {
total += len(pairs)
}
log.Infof(
@@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
return
}
- defer func() {
- if err := d.firewall.Flush(); err != nil {
- log.Error("failed to flush firewall rules: ", err)
- }
- }()
+ d.applyPeerACLs(networkMap)
+ // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
+ // then the mgmt server is older than the client, and we need to allow all traffic for routes
+ isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
+ if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
+ log.Errorf("failed to set legacy management flag: %v", err)
+ }
+
+ if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
+ log.Errorf("Failed to apply route ACLs: %v", err)
+ }
+
+ if err := d.firewall.Flush(); err != nil {
+ log.Error("failed to flush firewall rules: ", err)
+ }
+}
+
+func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules, squashedProtocols := d.squashAcceptRules(networkMap)
- enableSSH := (networkMap.PeerConfig != nil &&
+ enableSSH := networkMap.PeerConfig != nil &&
networkMap.PeerConfig.SshConfig != nil &&
- networkMap.PeerConfig.SshConfig.SshEnabled)
- if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
+ networkMap.PeerConfig.SshConfig.SshEnabled
+ if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
enableSSH = enableSSH && !ok
}
- if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok {
+ if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok {
enableSSH = enableSSH && !ok
}
@@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
if enableSSH {
rules = append(rules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
Port: strconv.Itoa(ssh.DefaultSSHPort),
})
}
@@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
rules = append(rules,
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
&mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
)
}
- newRulePairs := make(map[string][]firewall.Rule)
+ newRulePairs := make(map[id.RuleID][]firewall.Rule)
ipsetByRuleSelectors := make(map[string]string)
for _, r := range rules {
@@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
break
}
if len(rules) > 0 {
- d.rulesPairs[pairID] = rulePair
+ d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
}
- for pairID, rules := range d.rulesPairs {
+ for pairID, rules := range d.peerRulesPairs {
if _, ok := newRulePairs[pairID]; !ok {
for _, rule := range rules {
- if err := d.firewall.DeleteRule(rule); err != nil {
- log.Errorf("failed to delete firewall rule: %v", err)
+ if err := d.firewall.DeletePeerRule(rule); err != nil {
+ log.Errorf("failed to delete peer firewall rule: %v", err)
continue
}
}
- delete(d.rulesPairs, pairID)
+ delete(d.peerRulesPairs, pairID)
}
}
- d.rulesPairs = newRulePairs
+ d.peerRulesPairs = newRulePairs
+}
+
+func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
+ var newRouteRules = make(map[id.RuleID]struct{})
+ for _, rule := range rules {
+ id, err := d.applyRouteACL(rule)
+ if err != nil {
+ return fmt.Errorf("apply route ACL: %w", err)
+ }
+ newRouteRules[id] = struct{}{}
+ }
+
+ for id := range d.routeRules {
+ if _, ok := newRouteRules[id]; !ok {
+ if err := d.firewall.DeleteRouteRule(id); err != nil {
+ log.Errorf("failed to delete route firewall rule: %v", err)
+ continue
+ }
+ delete(d.routeRules, id)
+ }
+ }
+ d.routeRules = newRouteRules
+ return nil
+}
+
+func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
+ if len(rule.SourceRanges) == 0 {
+ return "", fmt.Errorf("source ranges is empty")
+ }
+
+ var sources []netip.Prefix
+ for _, sourceRange := range rule.SourceRanges {
+ source, err := netip.ParsePrefix(sourceRange)
+ if err != nil {
+ return "", fmt.Errorf("parse source range: %w", err)
+ }
+ sources = append(sources, source)
+ }
+
+ var destination netip.Prefix
+ if rule.IsDynamic {
+ destination = getDefault(sources[0])
+ } else {
+ var err error
+ destination, err = netip.ParsePrefix(rule.Destination)
+ if err != nil {
+ return "", fmt.Errorf("parse destination: %w", err)
+ }
+ }
+
+ protocol, err := convertToFirewallProtocol(rule.Protocol)
+ if err != nil {
+ return "", fmt.Errorf("invalid protocol: %w", err)
+ }
+
+ action, err := convertFirewallAction(rule.Action)
+ if err != nil {
+ return "", fmt.Errorf("invalid action: %w", err)
+ }
+
+ dPorts := convertPortInfo(rule.PortInfo)
+
+ addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
+ if err != nil {
+ return "", fmt.Errorf("add route rule: %w", err)
+ }
+
+ return id.RuleID(addedRule.GetRuleID()), nil
}
func (d *DefaultManager) protoRuleToFirewallRule(
r *mgmProto.FirewallRule,
ipsetName string,
-) (string, []firewall.Rule, error) {
+) (id.RuleID, []firewall.Rule, error) {
ip := net.ParseIP(r.PeerIP)
if ip == nil {
return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule")
@@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
}
- ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "")
- if rulesPair, ok := d.rulesPairs[ruleID]; ok {
+ ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
+ if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil
}
var rules []firewall.Rule
switch r.Direction {
- case mgmProto.FirewallRule_IN:
+ case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
- case mgmProto.FirewallRule_OUT:
+ case mgmProto.RuleDirection_OUT:
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
- rule, err := d.firewall.AddFiltering(
+ rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules(
return rules, nil
}
- rule, err = d.firewall.AddFiltering(
+ rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules(
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
- rule, err := d.firewall.AddFiltering(
+ rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules(
return rules, nil
}
- rule, err = d.firewall.AddFiltering(
+ rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
@@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules(
return append(rules, rule...), nil
}
-// getRuleID() returns unique ID for the rule based on its parameters.
-func (d *DefaultManager) getRuleID(
+// getPeerRuleID() returns unique ID for the rule based on its parameters.
+func (d *DefaultManager) getPeerRuleID(
ip net.IP,
proto firewall.Protocol,
direction int,
port *firewall.Port,
action firewall.Action,
comment string,
-) string {
+) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
if port != nil {
idStr += port.String()
}
- return hex.EncodeToString(md5.New().Sum([]byte(idStr)))
+ return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr))))
}
// squashAcceptRules does complex logic to convert many rules which allows connection by traffic type
@@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID(
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(
networkMap *mgmProto.NetworkMap,
-) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) {
+) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {
@@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules(
}
}
- type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int
+ type protoMatch map[mgmProto.RuleProtocol]map[string]int
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{}
- squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{}
+ squashedProtocols := map[mgmProto.RuleProtocol]struct{}{}
// this function we use to do calculation, can we squash the rules by protocol or not.
// We summ amount of Peers IP for given protocol we found in original rules list.
@@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules(
//
// We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
- drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != ""
+ drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop {
protocols[r.Protocol] = map[string]int{}
return
@@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules(
for i, r := range networkMap.FirewallRules {
// calculate squash for different directions
- if r.Direction == mgmProto.FirewallRule_IN {
+ if r.Direction == mgmProto.RuleDirection_IN {
addRuleToCalculationMap(i, r, in)
} else {
addRuleToCalculationMap(i, r, out)
@@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules(
// order of squashing by protocol is important
// only for their first element ALL, it must be done first
- protocolOrders := []mgmProto.FirewallRuleProtocol{
- mgmProto.FirewallRule_ALL,
- mgmProto.FirewallRule_ICMP,
- mgmProto.FirewallRule_TCP,
- mgmProto.FirewallRule_UDP,
+ protocolOrders := []mgmProto.RuleProtocol{
+ mgmProto.RuleProtocol_ALL,
+ mgmProto.RuleProtocol_ICMP,
+ mgmProto.RuleProtocol_TCP,
+ mgmProto.RuleProtocol_UDP,
}
- squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) {
+ squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
// don't squash if :
@@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules(
squashedRules = append(squashedRules, &mgmProto.FirewallRule{
PeerIP: "0.0.0.0",
Direction: direction,
- Action: mgmProto.FirewallRule_ACCEPT,
+ Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol,
})
squashedProtocols[protocol] = struct{}{}
- if protocol == mgmProto.FirewallRule_ALL {
+ if protocol == mgmProto.RuleProtocol_ALL {
// if we have ALL traffic type squashed rule
// it allows all other type of traffic, so we can stop processing
break
@@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules(
}
}
- squash(in, mgmProto.FirewallRule_IN)
- squash(out, mgmProto.FirewallRule_OUT)
+ squash(in, mgmProto.RuleDirection_IN)
+ squash(out, mgmProto.RuleDirection_OUT)
// if all protocol was squashed everything is allow and we can ignore all other rules
- if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok {
+ if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok {
return squashedRules, squashedProtocols
}
@@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st
return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port)
}
-func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) {
+func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
log.Debugf("rollback ACL to previous state")
for _, rules := range newRulePairs {
for _, rule := range rules {
- if err := d.firewall.DeleteRule(rule); err != nil {
+ if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
}
}
}
}
-func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) {
+func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) {
switch protocol {
- case mgmProto.FirewallRule_TCP:
+ case mgmProto.RuleProtocol_TCP:
return firewall.ProtocolTCP, nil
- case mgmProto.FirewallRule_UDP:
+ case mgmProto.RuleProtocol_UDP:
return firewall.ProtocolUDP, nil
- case mgmProto.FirewallRule_ICMP:
+ case mgmProto.RuleProtocol_ICMP:
return firewall.ProtocolICMP, nil
- case mgmProto.FirewallRule_ALL:
+ case mgmProto.RuleProtocol_ALL:
return firewall.ProtocolALL, nil
default:
return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String())
@@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo
return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil
}
-func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) {
+func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) {
switch action {
- case mgmProto.FirewallRule_ACCEPT:
+ case mgmProto.RuleAction_ACCEPT:
return firewall.ActionAccept, nil
- case mgmProto.FirewallRule_DROP:
+ case mgmProto.RuleAction_DROP:
return firewall.ActionDrop, nil
default:
return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action)
}
}
+
+func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
+ if portInfo == nil {
+ return nil
+ }
+
+ if portInfo.GetPort() != 0 {
+ return &firewall.Port{
+ Values: []int{int(portInfo.GetPort())},
+ }
+ }
+
+ if portInfo.GetRange() != nil {
+ return &firewall.Port{
+ IsRange: true,
+ Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
+ }
+ }
+
+ return nil
+}
+
+func getDefault(prefix netip.Prefix) netip.Prefix {
+ if prefix.Addr().Is6() {
+ return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
+ }
+ return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
+}
diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go
index 494d54bf2..7d999669a 100644
--- a/client/internal/acl/manager_test.go
+++ b/client/internal/acl/manager_test.go
@@ -9,8 +9,8 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
- "github.com/netbirdio/netbird/iface"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
@@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_DROP,
- Protocol: mgmProto.FirewallRule_UDP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_DROP,
+ Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
@@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) {
t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap)
- if len(acl.rulesPairs) != 2 {
- t.Errorf("firewall rules not applied: %v", acl.rulesPairs)
+ if len(acl.peerRulesPairs) != 2 {
+ t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
return
}
})
t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{}
- for id := range acl.rulesPairs {
- existedPairs[id] = struct{}{}
+ for id := range acl.peerRulesPairs {
+ existedPairs[id.GetRuleID()] = struct{}{}
}
// remove first rule
@@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules,
&mgmProto.FirewallRule{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_DROP,
- Protocol: mgmProto.FirewallRule_ICMP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_DROP,
+ Protocol: mgmProto.RuleProtocol_ICMP,
},
)
acl.ApplyFiltering(networkMap)
// we should have one old and one new rule in the existed rules
- if len(acl.rulesPairs) != 2 {
+ if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied")
return
}
// check that old rule was removed
previousCount := 0
- for id := range acl.rulesPairs {
- if _, ok := existedPairs[id]; ok {
+ for id := range acl.peerRulesPairs {
+ if _, ok := existedPairs[id.GetRuleID()]; ok {
previousCount++
}
}
@@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true
- if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 {
- t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs))
+ if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
+ t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
- if len(acl.rulesPairs) != 2 {
- t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs))
+ if len(acl.peerRulesPairs) != 2 {
+ t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
}
})
@@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
},
}
@@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
- case r.Direction != mgmProto.FirewallRule_IN:
+ case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
- case r.Protocol != mgmProto.FirewallRule_ALL:
+ case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
- case r.Action != mgmProto.FirewallRule_ACCEPT:
+ case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) {
case r.PeerIP != "0.0.0.0":
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
return
- case r.Direction != mgmProto.FirewallRule_OUT:
+ case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
- case r.Protocol != mgmProto.FirewallRule_ALL:
+ case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
- case r.Action != mgmProto.FirewallRule_ACCEPT:
+ case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
@@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_ALL,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_ALL,
},
{
PeerIP: "10.93.0.4",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_UDP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
@@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
- Direction: mgmProto.FirewallRule_IN,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_TCP,
+ Direction: mgmProto.RuleDirection_IN,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
- Direction: mgmProto.FirewallRule_OUT,
- Action: mgmProto.FirewallRule_ACCEPT,
- Protocol: mgmProto.FirewallRule_UDP,
+ Direction: mgmProto.RuleDirection_OUT,
+ Action: mgmProto.RuleAction_ACCEPT,
+ Protocol: mgmProto.RuleProtocol_UDP,
},
},
}
@@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap)
- if len(acl.rulesPairs) != 4 {
- t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs))
+ if len(acl.peerRulesPairs) != 4 {
+ t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return
}
}
diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go
index 621b29513..3ed12b6dd 100644
--- a/client/internal/acl/mocks/iface_mapper.go
+++ b/client/internal/acl/mocks/iface_mapper.go
@@ -8,7 +8,8 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
- iface "github.com/netbirdio/netbird/iface"
+ iface "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// MockIFaceMapper is a mock of IFaceMapper interface.
@@ -77,7 +78,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call {
}
// SetFilter mocks base method.
-func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error {
+func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetFilter", arg0)
ret0, _ := ret[0].(error)
diff --git a/client/internal/config.go b/client/internal/config.go
index 725703c43..ee54c6380 100644
--- a/client/internal/config.go
+++ b/client/internal/config.go
@@ -16,9 +16,9 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
- "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/util"
)
@@ -117,6 +117,11 @@ type Config struct {
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
func ReadConfig(configPath string) (*Config, error) {
if configFileIsExists(configPath) {
+ err := util.EnforcePermission(configPath)
+ if err != nil {
+ log.Errorf("failed to enforce permission on config dir: %v", err)
+ }
+
config := &Config{}
if _, err := util.ReadJson(configPath, config); err != nil {
return nil, err
@@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
- err = WriteOutConfig(input.ConfigPath, cfg)
+ err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
return cfg, err
}
if isPreSharedKeyHidden(input.PreSharedKey) {
input.PreSharedKey = nil
}
+ err := util.EnforcePermission(input.ConfigPath)
+ if err != nil {
+ log.Errorf("failed to enforce permission on config dir: %v", err)
+ }
return update(input)
}
diff --git a/client/internal/connect.go b/client/internal/connect.go
index 1cfabe910..c77f95603 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -17,15 +17,18 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
- "github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ relayClient "github.com/netbirdio/netbird/relay/client"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/version"
@@ -55,22 +58,20 @@ func NewConnectClient(
// Run with main logic.
func (c *ConnectClient) Run() error {
- return c.run(MobileDependency{}, nil, nil, nil, nil)
+ return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(
- mgmProbe *Probe,
- signalProbe *Probe,
- relayProbe *Probe,
- wgProbe *Probe,
+ probes *ProbeHolder,
+ runningChan chan error,
) error {
- return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe)
+ return c.run(MobileDependency{}, probes, runningChan)
}
// RunOnAndroid with main logic on mobile system
func (c *ConnectClient) RunOnAndroid(
- tunAdapter iface.TunAdapter,
+ tunAdapter device.TunAdapter,
iFaceDiscover stdnet.ExternalIFaceDiscover,
networkChangeListener listener.NetworkChangeListener,
dnsAddresses []string,
@@ -84,7 +85,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
- return c.run(mobileDependency, nil, nil, nil, nil)
+ return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) RunOniOS(
@@ -100,15 +101,13 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
}
- return c.run(mobileDependency, nil, nil, nil, nil)
+ return c.run(mobileDependency, nil, nil)
}
func (c *ConnectClient) run(
mobileDependency MobileDependency,
- mgmProbe *Probe,
- signalProbe *Probe,
- relayProbe *Probe,
- wgProbe *Probe,
+ probes *ProbeHolder,
+ runningChan chan error,
) error {
defer func() {
if r := recover(); r != nil {
@@ -160,12 +159,11 @@ func (c *ConnectClient) run(
}
defer c.statusRecorder.ClientStop()
+ runningChanOpen := true
operation := func() error {
// if context cancelled we not start new backoff cycle
- select {
- case <-c.ctx.Done():
+ if c.isContextCancelled() {
return nil
- default:
}
state.Set(StatusConnecting)
@@ -187,8 +185,7 @@ func (c *ConnectClient) run(
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
defer func() {
- err = mgmClient.Close()
- if err != nil {
+ if err = mgmClient.Close(); err != nil {
log.Warnf("failed to close the Management service client %v", err)
}
}()
@@ -199,6 +196,7 @@ func (c *ConnectClient) run(
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
+ _ = c.Stop()
return backoff.Permanent(wrapErr(err)) // unrecoverable error
}
return wrapErr(err)
@@ -208,10 +206,9 @@ func (c *ConnectClient) run(
localPeerState := peer.LocalPeerState{
IP: loginResp.GetPeerConfig().GetAddress(),
PubKey: myPrivateKey.PublicKey().String(),
- KernelInterface: iface.WireGuardModuleIsLoaded(),
+ KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: loginResp.GetPeerConfig().GetFqdn(),
}
-
c.statusRecorder.UpdateLocalPeerState(localPeerState)
signalURL := fmt.Sprintf("%s://%s",
@@ -244,6 +241,23 @@ func (c *ConnectClient) run(
c.statusRecorder.MarkSignalConnected()
+ relayURLs, token := parseRelayInfo(loginResp)
+ relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
+ if len(relayURLs) > 0 {
+ if token != nil {
+ if err := relayManager.UpdateToken(token); err != nil {
+ log.Errorf("failed to update token: %s", err)
+ return wrapErr(err)
+ }
+ }
+ log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
+ if err = relayManager.Serve(); err != nil {
+ log.Error(err)
+ return wrapErr(err)
+ }
+ c.statusRecorder.SetRelayMgr(relayManager)
+ }
+
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig)
@@ -255,11 +269,17 @@ func (c *ConnectClient) run(
checks := loginResp.GetChecks()
c.engineMutex.Lock()
- c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks)
+ if c.engine != nil && c.engine.ctx.Err() != nil {
+ log.Info("Stopping Netbird Engine")
+ if err := c.engine.Stop(); err != nil {
+ log.Errorf("Failed to stop engine: %v", err)
+ }
+ }
+ c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
+
c.engineMutex.Unlock()
- err = c.engine.Start()
- if err != nil {
+ if err := c.engine.Start(); err != nil {
log.Errorf("error while starting Netbird Connection Engine: %s", err)
return wrapErr(err)
}
@@ -267,17 +287,17 @@ func (c *ConnectClient) run(
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
+ if runningChan != nil && runningChanOpen {
+ runningChan <- nil
+ close(runningChan)
+ runningChanOpen = false
+ }
+
<-engineCtx.Done()
c.statusRecorder.ClientTeardown()
backOff.Reset()
- err = c.engine.Stop()
- if err != nil {
- log.Errorf("failed stopping engine %v", err)
- return wrapErr(err)
- }
-
log.Info("stopped NetBird client")
if _, err := state.Status(); errors.Is(err, ErrResetConnection) {
@@ -293,13 +313,31 @@ func (c *ConnectClient) run(
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin)
+ _ = c.Stop()
}
return err
}
return nil
}
+func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) {
+ relayCfg := loginResp.GetWiretrusteeConfig().GetRelay()
+ if relayCfg == nil {
+ return nil, nil
+ }
+
+ token := &hmac.Token{
+ Payload: relayCfg.GetTokenPayload(),
+ Signature: relayCfg.GetTokenSignature(),
+ }
+
+ return relayCfg.GetUrls(), token
+}
+
func (c *ConnectClient) Engine() *Engine {
+ if c == nil {
+ return nil
+ }
var e *Engine
c.engineMutex.Lock()
e = c.engine
@@ -307,6 +345,28 @@ func (c *ConnectClient) Engine() *Engine {
return e
}
+func (c *ConnectClient) Stop() error {
+ if c == nil {
+ return nil
+ }
+ c.engineMutex.Lock()
+ defer c.engineMutex.Unlock()
+
+ if c.engine == nil {
+ return nil
+ }
+ return c.engine.Stop()
+}
+
+func (c *ConnectClient) isContextCancelled() bool {
+ select {
+ case <-c.ctx.Done():
+ return true
+ default:
+ return false
+ }
+}
+
// createEngineConfig converts configuration received from Management Service to EngineConfig
func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) {
nm := false
@@ -397,19 +457,43 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
return notifier
}
-func freePort(start int) (int, error) {
+// freePort attempts to determine if the provided port is available; if it is not, it asks the system for a free port.
+func freePort(initPort int) (int, error) {
addr := net.UDPAddr{}
- if start == 0 {
- start = iface.DefaultWgPort
+ if initPort == 0 {
+ initPort = iface.DefaultWgPort
}
- for x := start; x <= 65535; x++ {
- addr.Port = x
- conn, err := net.ListenUDP("udp", &addr)
- if err != nil {
- continue
- }
- conn.Close()
- return x, nil
+
+ addr.Port = initPort
+
+ conn, err := net.ListenUDP("udp", &addr)
+ if err == nil {
+ closeConnWithLog(conn)
+ return initPort, nil
+ }
+
+ // if the port is already in use, ask the system for a free port
+ addr.Port = 0
+ conn, err = net.ListenUDP("udp", &addr)
+ if err != nil {
+ return 0, fmt.Errorf("unable to get a free port: %v", err)
+ }
+
+ udpAddr, ok := conn.LocalAddr().(*net.UDPAddr)
+ if !ok {
+ return 0, errors.New("wrong address type when getting a free port")
+ }
+ closeConnWithLog(conn)
+ return udpAddr.Port, nil
+}
+
+func closeConnWithLog(conn *net.UDPConn) {
+ startClosing := time.Now()
+ err := conn.Close()
+ if err != nil {
+ log.Warnf("closing probe port %d failed: %v. NetBird will still attempt to use this port for connection.", conn.LocalAddr().(*net.UDPAddr).Port, err)
+ }
+ if time.Since(startClosing) > time.Second {
+ log.Warnf("closing the testing port %d took %s. Usually it is safe to ignore, but continuous warnings may indicate a problem.", conn.LocalAddr().(*net.UDPAddr).Port, time.Since(startClosing))
}
- return 0, errors.New("no free ports")
}
diff --git a/client/internal/connect_test.go b/client/internal/connect_test.go
index 6f4a6bbb7..78b4b06e8 100644
--- a/client/internal/connect_test.go
+++ b/client/internal/connect_test.go
@@ -7,51 +7,55 @@ import (
func Test_freePort(t *testing.T) {
tests := []struct {
- name string
- port int
- want int
- wantErr bool
+ name string
+ port int
+ want int
+ shouldMatch bool
}{
{
- name: "available",
- port: 51820,
- want: 51820,
- wantErr: false,
+ name: "not provided, fallback to default",
+ port: 0,
+ want: 51820,
+ shouldMatch: true,
},
{
- name: "notavailable",
- port: 51830,
- want: 51831,
- wantErr: false,
+ name: "provided and available",
+ port: 51821,
+ want: 51821,
+ shouldMatch: true,
},
{
- name: "noports",
- port: 65535,
- want: 0,
- wantErr: true,
+ name: "provided and not available",
+ port: 51830,
+ want: 51830,
+ shouldMatch: false,
},
}
+ c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
+ if err != nil {
+ t.Errorf("freePort error = %v", err)
+ }
+ defer func(c1 *net.UDPConn) {
+ _ = c1.Close()
+ }(c1)
+
for _, tt := range tests {
- c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
- if err != nil {
- t.Errorf("freePort error = %v", err)
- }
- c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535})
- if err != nil {
- t.Errorf("freePort error = %v", err)
- }
t.Run(tt.name, func(t *testing.T) {
got, err := freePort(tt.port)
- if (err != nil) != tt.wantErr {
- t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr)
- return
+
+ if err != nil {
+ t.Errorf("got an error while getting free port: %v", err)
}
- if got != tt.want {
- t.Errorf("freePort() = %v, want %v", got, tt.want)
+
+ if tt.shouldMatch && got != tt.want {
+ t.Errorf("got a different port %v, want %v", got, tt.want)
+ }
+
+ if !tt.shouldMatch && got == tt.want {
+ t.Errorf("got the same port %v, want a different port", tt.want)
}
})
- c1.Close()
- c2.Close()
+
}
}
diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go
index 5a0047700..857964406 100644
--- a/client/internal/dns/response_writer_test.go
+++ b/client/internal/dns/response_writer_test.go
@@ -9,7 +9,7 @@ import (
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
- "github.com/netbirdio/netbird/iface/mocks"
+ "github.com/netbirdio/netbird/client/iface/mocks"
)
func TestResponseWriterLocalAddr(t *testing.T) {
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index b9552bc17..53d18a678 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -15,16 +15,18 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+ pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter"
- "github.com/netbirdio/netbird/iface"
- pfmock "github.com/netbirdio/netbird/iface/mocks"
)
type mocWGIface struct {
- filter iface.PacketFilter
+ filter device.PacketFilter
}
func (w *mocWGIface) Name() string {
@@ -43,11 +45,11 @@ func (w *mocWGIface) ToInterface() *net.Interface {
panic("implement me")
}
-func (w *mocWGIface) GetFilter() iface.PacketFilter {
+func (w *mocWGIface) GetFilter() device.PacketFilter {
return w.filter
}
-func (w *mocWGIface) GetDevice() *iface.DeviceWrapper {
+func (w *mocWGIface) GetDevice() *device.FilteredDevice {
panic("implement me")
}
@@ -59,13 +61,13 @@ func (w *mocWGIface) IsUserspaceBind() bool {
return false
}
-func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error {
+func (w *mocWGIface) SetFilter(filter device.PacketFilter) error {
w.filter = filter
return nil
}
-func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) {
- return iface.WGStats{}, nil
+func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) {
+ return configurer.WGStats{}, nil
}
var zoneRecords = []nbdns.SimpleRecord{
diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go
index 2f08e8d52..69bc83659 100644
--- a/client/internal/dns/wgiface.go
+++ b/client/internal/dns/wgiface.go
@@ -5,7 +5,9 @@ package dns
import (
"net"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
)
// WGIface defines subset methods of interface required for manager
@@ -14,7 +16,7 @@ type WGIface interface {
Address() iface.WGAddress
ToInterface() *net.Interface
IsUserspaceBind() bool
- GetFilter() iface.PacketFilter
- GetDevice() *iface.DeviceWrapper
- GetStats(peerKey string) (iface.WGStats, error)
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
}
diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go
index f8bb80fb9..765132fdb 100644
--- a/client/internal/dns/wgiface_windows.go
+++ b/client/internal/dns/wgiface_windows.go
@@ -1,14 +1,18 @@
package dns
-import "github.com/netbirdio/netbird/iface"
+import (
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
+ "github.com/netbirdio/netbird/client/iface/device"
+)
// WGIface defines subset methods of interface required for manager
type WGIface interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
- GetFilter() iface.PacketFilter
- GetDevice() *iface.DeviceWrapper
- GetStats(peerKey string) (iface.WGStats, error)
+ GetFilter() device.PacketFilter
+ GetDevice() *device.FilteredDevice
+ GetStats(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDString() (string, error)
}
diff --git a/client/internal/engine.go b/client/internal/engine.go
index d65322d6a..c51901a22 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -13,6 +13,7 @@ import (
"slices"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/pion/ice/v3"
@@ -22,8 +23,12 @@ import (
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
+
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/relay"
@@ -34,11 +39,11 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
- "github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto"
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
sProto "github.com/netbirdio/netbird/signal/proto"
@@ -101,7 +106,8 @@ type EngineConfig struct {
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
type Engine struct {
// signal is a Signal Service client
- signal signal.Client
+ signal signal.Client
+ signaler *peer.Signaler
// mgmClient is a Management Service client
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
@@ -122,7 +128,8 @@ type Engine struct {
// STUNs is a list of STUN servers used by ICE
STUNs []*stun.URI
// TURNs is a list of STUN servers used by ICE
- TURNs []*stun.URI
+ TURNs []*stun.URI
+ stunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
@@ -134,7 +141,7 @@ type Engine struct {
ctx context.Context
cancel context.CancelFunc
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault
@@ -155,15 +162,12 @@ type Engine struct {
dnsServer dns.Server
- mgmProbe *Probe
- signalProbe *Probe
- relayProbe *Probe
- wgProbe *Probe
-
- wgConnWorker sync.WaitGroup
+ probes *ProbeHolder
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
+
+ relayManager *relayClient.Manager
}
// Peer is an instance of the Connection Peer
@@ -178,6 +182,7 @@ func NewEngine(
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
+ relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
@@ -188,13 +193,11 @@ func NewEngine(
clientCancel,
signalClient,
mgmClient,
+ relayManager,
config,
mobileDep,
statusRecorder,
nil,
- nil,
- nil,
- nil,
checks,
)
}
@@ -205,21 +208,20 @@ func NewEngineWithProbes(
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
+ relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
- mgmProbe *Probe,
- signalProbe *Probe,
- relayProbe *Probe,
- wgProbe *Probe,
+ probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
-
return &Engine{
clientCtx: clientCtx,
clientCancel: clientCancel,
signal: signalClient,
+ signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
+ relayManager: relayManager,
peerConns: make(map[string]*peer.Conn),
syncMsgMux: &sync.Mutex{},
config: config,
@@ -229,22 +231,20 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
- mgmProbe: mgmProbe,
- signalProbe: signalProbe,
- relayProbe: relayProbe,
- wgProbe: wgProbe,
+ probes: probes,
checks: checks,
}
}
func (e *Engine) Stop() error {
+ if e == nil {
+ // this seems to be a very odd case, but it can happen if the netbird down command arrives before the engine has fully started
+ log.Debugf("tried stopping engine that is nil")
+ return nil
+ }
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
- if e.cancel != nil {
- e.cancel()
- }
-
// stopping network monitor first to avoid starting the engine again
if e.networkMonitor != nil {
e.networkMonitor.Stop()
@@ -253,36 +253,24 @@ func (e *Engine) Stop() error {
err := e.removeAllPeers()
if err != nil {
- return err
+ return fmt.Errorf("failed to remove all peers: %s", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
+ if e.cancel != nil {
+ e.cancel()
+ }
+
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
- e.wgConnWorker.Wait()
-
- maxWaitTime := 5 * time.Second
- timeout := time.After(maxWaitTime)
-
- for {
- if !e.IsWGIfaceUp() {
- log.Infof("stopped Netbird Engine")
- return nil
- }
-
- select {
- case <-timeout:
- return fmt.Errorf("timeout when waiting for interface shutdown")
- default:
- time.Sleep(100 * time.Millisecond)
- }
- }
+ log.Infof("stopped Netbird Engine")
+ return nil
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
@@ -305,7 +293,7 @@ func (e *Engine) Start() error {
e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
- e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort)
+ e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled")
@@ -331,7 +319,7 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
- e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes)
+ e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
@@ -480,80 +468,45 @@ func (e *Engine) removePeer(peerKey string) error {
conn, exists := e.peerConns[peerKey]
if exists {
delete(e.peerConns, peerKey)
- err := conn.Close()
- if err != nil {
- switch err.(type) {
- case *peer.ConnectionAlreadyClosedError:
- return nil
- default:
- return err
- }
- }
+ conn.Close()
}
return nil
}
-func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error {
- err := s.Send(&sProto.Message{
- Key: myKey.PublicKey().String(),
- RemoteKey: remoteKey.String(),
- Body: &sProto.Body{
- Type: sProto.Body_CANDIDATE,
- Payload: candidate.Marshal(),
- },
- })
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func sendSignal(message *sProto.Message, s signal.Client) error {
- return s.Send(message)
-}
-
-// SignalOfferAnswer signals either an offer or an answer to remote peer
-func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client,
- isAnswer bool) error {
- var t sProto.Body_Type
- if isAnswer {
- t = sProto.Body_ANSWER
- } else {
- t = sProto.Body_OFFER
- }
-
- msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{
- UFrag: offerAnswer.IceCredentials.UFrag,
- Pwd: offerAnswer.IceCredentials.Pwd,
- }, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr)
- if err != nil {
- return err
- }
-
- err = s.Send(msg)
- if err != nil {
- return err
- }
-
- return nil
-}
-
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if update.GetWiretrusteeConfig() != nil {
- err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns())
+ wCfg := update.GetWiretrusteeConfig()
+ err := e.updateTURNs(wCfg.GetTurns())
if err != nil {
- return err
+ return fmt.Errorf("update TURNs: %w", err)
}
- err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns())
+ err = e.updateSTUNs(wCfg.GetStuns())
if err != nil {
- return err
+ return fmt.Errorf("update STUNs: %w", err)
}
+ var stunTurn []*stun.URI
+ stunTurn = append(stunTurn, e.STUNs...)
+ stunTurn = append(stunTurn, e.TURNs...)
+ e.stunTurn.Store(stunTurn)
+
+ relayMsg := wCfg.GetRelay()
+ if relayMsg != nil {
+ c := &auth.Token{
+ Payload: relayMsg.GetTokenPayload(),
+ Signature: relayMsg.GetTokenSignature(),
+ }
+ if err := e.relayManager.UpdateToken(c); err != nil {
+ log.Errorf("failed to update relay token: %v", err)
+ return fmt.Errorf("update relay token: %w", err)
+ }
+ }
+
+ // todo update relay address in the relay manager
// todo update signal
}
@@ -667,7 +620,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
- KernelInterface: iface.WireGuardModuleIsLoaded(),
+ KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
})
@@ -752,6 +705,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
+ // Apply ACLs in the beginning to avoid security leaks
+ if e.acl != nil {
+ e.acl.ApplyFiltering(networkMap)
+ }
+
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
@@ -818,10 +776,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update dns server, err: %v", err)
}
- if e.acl != nil {
- e.acl.ApplyFiltering(networkMap)
- }
-
e.networkSerial = serial
// Test received (upstream) servers for availability right away instead of upon usage.
@@ -949,68 +903,13 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
}
- e.wgConnWorker.Add(1)
- go e.connWorker(conn, peerKey)
+ conn.Open()
}
return nil
}
-func (e *Engine) connWorker(conn *peer.Conn, peerKey string) {
- defer e.wgConnWorker.Done()
- for {
-
- // randomize starting time a bit
- minValue := 500
- maxValue := 2000
- duration := time.Duration(rand.Intn(maxValue-minValue)+minValue) * time.Millisecond
- select {
- case <-e.ctx.Done():
- return
- case <-time.After(duration):
- }
-
- // if peer has been removed -> give up
- if !e.peerExists(peerKey) {
- log.Debugf("peer %s doesn't exist anymore, won't retry connection", peerKey)
- return
- }
-
- if !e.signal.Ready() {
- log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey)
- continue
- }
-
- // we might have received new STUN and TURN servers meanwhile, so update them
- e.syncMsgMux.Lock()
- conn.UpdateStunTurn(append(e.STUNs, e.TURNs...))
- e.syncMsgMux.Unlock()
-
- err := conn.Open(e.ctx)
- if err != nil {
- log.Debugf("connection to peer %s failed: %v", peerKey, err)
- var connectionClosedError *peer.ConnectionClosedError
- switch {
- case errors.As(err, &connectionClosedError):
- // conn has been forced to close, so we exit the loop
- return
- default:
- }
- }
- }
-}
-
-func (e *Engine) peerExists(peerKey string) bool {
- e.syncMsgMux.Lock()
- defer e.syncMsgMux.Unlock()
- _, ok := e.peerConns[peerKey]
- return ok
-}
-
func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) {
log.Debugf("creating peer connection %s", pubKey)
- var stunTurn []*stun.URI
- stunTurn = append(stunTurn, e.STUNs...)
- stunTurn = append(stunTurn, e.TURNs...)
wgConfig := peer.WgConfig{
RemoteKey: pubKey,
@@ -1043,52 +942,29 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
// randomize connection timeout
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
config := peer.ConnConfig{
- Key: pubKey,
- LocalKey: e.config.WgPrivateKey.PublicKey().String(),
- StunTurn: stunTurn,
- InterfaceBlackList: e.config.IFaceBlackList,
- DisableIPv6Discovery: e.config.DisableIPv6Discovery,
- Timeout: timeout,
- UDPMux: e.udpMux.UDPMuxDefault,
- UDPMuxSrflx: e.udpMux,
- WgConfig: wgConfig,
- LocalWgPort: e.config.WgPort,
- NATExternalIPs: e.parseNATExternalIPMappings(),
- RosenpassPubKey: e.getRosenpassPubKey(),
- RosenpassAddr: e.getRosenpassAddr(),
+ Key: pubKey,
+ LocalKey: e.config.WgPrivateKey.PublicKey().String(),
+ Timeout: timeout,
+ WgConfig: wgConfig,
+ LocalWgPort: e.config.WgPort,
+ RosenpassPubKey: e.getRosenpassPubKey(),
+ RosenpassAddr: e.getRosenpassAddr(),
+ ICEConfig: peer.ICEConfig{
+ StunTurn: &e.stunTurn,
+ InterfaceBlackList: e.config.IFaceBlackList,
+ DisableIPv6Discovery: e.config.DisableIPv6Discovery,
+ UDPMux: e.udpMux.UDPMuxDefault,
+ UDPMuxSrflx: e.udpMux,
+ NATExternalIPs: e.parseNATExternalIPMappings(),
+ },
}
- peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover)
+ peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager)
if err != nil {
return nil, err
}
- wgPubKey, err := wgtypes.ParseKey(pubKey)
- if err != nil {
- return nil, err
- }
-
- signalOffer := func(offerAnswer peer.OfferAnswer) error {
- return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false)
- }
-
- signalCandidate := func(candidate ice.Candidate) error {
- return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal)
- }
-
- signalAnswer := func(offerAnswer peer.OfferAnswer) error {
- return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true)
- }
-
- peerConn.SetSignalCandidate(signalCandidate)
- peerConn.SetSignalOffer(signalOffer)
- peerConn.SetSignalAnswer(signalAnswer)
- peerConn.SetSendSignalMessage(func(message *sProto.Message) error {
- return sendSignal(message, e.signal)
- })
-
if e.rpManager != nil {
-
peerConn.SetOnConnected(e.rpManager.OnConnected)
peerConn.SetOnDisconnected(e.rpManager.OnDisconnected)
}
@@ -1131,6 +1007,7 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
+ RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_ANSWER:
remoteCred, err := signal.UnMarshalCredential(msg)
@@ -1153,6 +1030,7 @@ func (e *Engine) receiveSignalEvents() {
Version: msg.GetBody().GetNetBirdVersion(),
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
+ RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
})
case sProto.Body_CANDIDATE:
candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload)
@@ -1161,7 +1039,7 @@ func (e *Engine) receiveSignalEvents() {
return err
}
- conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
+ go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
case sProto.Body_MODE:
}
@@ -1239,10 +1117,7 @@ func (e *Engine) close() {
}
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
- if e.dnsServer != nil {
- e.dnsServer.Stop()
- e.dnsServer = nil
- }
+ e.stopDNSServer()
if e.routeManager != nil {
e.routeManager.Stop()
@@ -1291,15 +1166,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
log.Errorf("failed to create pion's stdnet: %s", err)
}
- var mArgs *iface.MobileIFaceArguments
+ var mArgs *device.MobileIFaceArguments
switch runtime.GOOS {
case "android":
- mArgs = &iface.MobileIFaceArguments{
+ mArgs = &device.MobileIFaceArguments{
TunAdapter: e.mobileDep.TunAdapter,
TunFd: int(e.mobileDep.FileDescriptor),
}
case "ios":
- mArgs = &iface.MobileIFaceArguments{
+ mArgs = &device.MobileIFaceArguments{
TunFd: int(e.mobileDep.FileDescriptor),
}
default:
@@ -1415,24 +1290,27 @@ func (e *Engine) getRosenpassAddr() string {
}
func (e *Engine) receiveProbeEvents() {
- if e.signalProbe != nil {
- go e.signalProbe.Receive(e.ctx, func() bool {
+ if e.probes == nil {
+ return
+ }
+ if e.probes.SignalProbe != nil {
+ go e.probes.SignalProbe.Receive(e.ctx, func() bool {
healthy := e.signal.IsHealthy()
log.Debugf("received signal probe request, healthy: %t", healthy)
return healthy
})
}
- if e.mgmProbe != nil {
- go e.mgmProbe.Receive(e.ctx, func() bool {
+ if e.probes.MgmProbe != nil {
+ go e.probes.MgmProbe.Receive(e.ctx, func() bool {
healthy := e.mgmClient.IsHealthy()
log.Debugf("received management probe request, healthy: %t", healthy)
return healthy
})
}
- if e.relayProbe != nil {
- go e.relayProbe.Receive(e.ctx, func() bool {
+ if e.probes.RelayProbe != nil {
+ go e.probes.RelayProbe.Receive(e.ctx, func() bool {
healthy := true
results := append(e.probeSTUNs(), e.probeTURNs()...)
@@ -1451,13 +1329,13 @@ func (e *Engine) receiveProbeEvents() {
})
}
- if e.wgProbe != nil {
- go e.wgProbe.Receive(e.ctx, func() bool {
+ if e.probes.WgProbe != nil {
+ go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request")
for _, peer := range e.peerConns {
key := peer.GetKey()
- wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key)
+ wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
}
@@ -1481,12 +1359,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
}
func (e *Engine) restartEngine() {
+ log.Info("restarting engine")
+ CtxGetState(e.ctx).Set(StatusConnecting)
+
if err := e.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
- if err := e.Start(); err != nil {
- log.Errorf("Failed to start engine: %v", err)
- }
+
+ _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
+ log.Infof("cancelling client, engine will be recreated")
+ e.clientCancel()
}
func (e *Engine) startNetworkMonitor() {
@@ -1508,6 +1390,7 @@ func (e *Engine) startNetworkMonitor() {
defer mu.Unlock()
if debounceTimer != nil {
+ log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
}
@@ -1517,7 +1400,7 @@ func (e *Engine) startNetworkMonitor() {
mu.Lock()
defer mu.Unlock()
- log.Infof("Network monitor detected network change, restarting engine")
+ log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
@@ -1542,26 +1425,23 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
return false, netip.Prefix{}, nil
}
+func (e *Engine) stopDNSServer() {
+ err := fmt.Errorf("DNS server stopped")
+ nsGroupStates := e.statusRecorder.GetDNSStates()
+ for i := range nsGroupStates {
+ nsGroupStates[i].Enabled = false
+ nsGroupStates[i].Error = err
+ }
+ e.statusRecorder.UpdateDNSStates(nsGroupStates)
+ if e.dnsServer != nil {
+ e.dnsServer.Stop()
+ e.dnsServer = nil
+ }
+}
+
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
-
-func (e *Engine) IsWGIfaceUp() bool {
- if e == nil || e.wgInterface == nil {
- return false
- }
- iface, err := net.InterfaceByName(e.wgInterface.Name())
- if err != nil {
- log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err)
- return false
- }
-
- if iface.Flags&net.FlagUp != 0 {
- return true
- }
-
- return false
-}
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index e0f85d211..29a8439a2 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -13,6 +13,7 @@ import (
"testing"
"time"
+ "github.com/google/uuid"
"github.com/pion/transport/v3/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -24,19 +25,21 @@ import (
"github.com/netbirdio/management-integrations/integrations"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
- "github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
mgmt "github.com/netbirdio/netbird/management/client"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto"
@@ -58,6 +61,12 @@ var (
}
)
+func TestMain(m *testing.M) {
+ _ = util.InitLog("debug", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
@@ -73,13 +82,23 @@ func TestEngine_SSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
- WgIfaceName: "utun101",
- WgAddr: "100.64.0.1/24",
- WgPrivateKey: key,
- WgPort: 33100,
- ServerSSHAllowed: true,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(
+ ctx, cancel,
+ &signal.MockClient{},
+ &mgmt.MockClient{},
+ relayMgr,
+ &EngineConfig{
+ WgIfaceName: "utun101",
+ WgAddr: "100.64.0.1/24",
+ WgPrivateKey: key,
+ WgPort: 33100,
+ ServerSSHAllowed: true,
+ },
+ MobileDependency{},
+ peer.NewRecorder("https://mgm"),
+ nil,
+ )
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
@@ -208,21 +227,29 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
- WgIfaceName: "utun102",
- WgAddr: "100.64.0.1/24",
- WgPrivateKey: key,
- WgPort: 33100,
- }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
- newNet, err := stdnet.NewNet()
- if err != nil {
- t.Fatal(err)
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(
+ ctx, cancel,
+ &signal.MockClient{},
+ &mgmt.MockClient{},
+ relayMgr,
+ &EngineConfig{
+ WgIfaceName: "utun102",
+ WgAddr: "100.64.0.1/24",
+ WgPrivateKey: key,
+ WgPort: 33100,
+ },
+ MobileDependency{},
+ peer.NewRecorder("https://mgm"),
+ nil)
+
+ wgIface := &iface.MockWGIface{
+ RemovePeerFunc: func(peerKey string) error {
+ return nil
+ },
}
- engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
- if err != nil {
- t.Fatal(err)
- }
- engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil)
+ engine.wgInterface = wgIface
+ engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
@@ -404,8 +431,8 @@ func TestEngine_Sync(t *testing.T) {
}
return nil
}
-
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
@@ -564,7 +591,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -734,7 +762,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
wgIfaceName := fmt.Sprintf("utun%d", 104+n)
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
- engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
WgIfaceName: wgIfaceName,
WgAddr: wgAddr,
WgPrivateKey: key,
@@ -845,6 +874,8 @@ func TestEngine_MultiplePeers(t *testing.T) {
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
+ guid := fmt.Sprintf("{%s}", uuid.New().String())
+ device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start()
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
@@ -1010,7 +1041,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
WgPort: wgPort,
}
- e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
+ relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String())
+ e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
e.ctx = ctx
return e, err
}
@@ -1025,7 +1057,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
- srv, err := signalServer.NewServer(otel.Meter(""))
+ srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
@@ -1044,6 +1076,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
config := &server.Config{
Stuns: []*server.Host{},
TURNConfig: &server.TURNConfig{},
+ Relay: &server.Relay{
+ Addresses: []string{"127.0.0.1:1234"},
+ CredentialsTTL: util.Duration{Duration: time.Hour},
+ Secret: "222222222222222222",
+ },
Signal: &server.Host{
Proto: "http",
URI: "localhost:10000",
@@ -1078,8 +1115,9 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
if err != nil {
return nil, "", err
}
- turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
return nil, "", err
}
diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go
index 2355c67c3..2b0c92cc6 100644
--- a/client/internal/mobile_dependency.go
+++ b/client/internal/mobile_dependency.go
@@ -1,16 +1,16 @@
package internal
import (
+ "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/stdnet"
- "github.com/netbirdio/netbird/iface"
)
// MobileDependency collect all dependencies for mobile platform
type MobileDependency struct {
// Android only
- TunAdapter iface.TunAdapter
+ TunAdapter device.TunAdapter
IFaceDiscover stdnet.ExternalIFaceDiscover
NetworkChangeListener listener.NetworkChangeListener
HostDNSAddresses []string
diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go
index 29df7ea7f..4dc2c1aa3 100644
--- a/client/internal/networkmonitor/monitor_bsd.go
+++ b/client/internal/networkmonitor/monitor_bsd.go
@@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
- log.Errorf("Network monitor: failed to close routing socket: %v", err)
+ log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}()
@@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
- log.Debugf("Network monitor: closed routing socket")
+ log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
@@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
- log.Errorf("Network monitor: failed to read from routing socket: %v", err)
+ log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
- log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
+ log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
@@ -61,11 +61,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
- log.Errorf("Network monitor: error parsing routing message: %v", err)
+ log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
- if !route.Dst.Addr().IsUnspecified() {
+ if route.Dst.Bits() != 0 {
continue
}
diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go
index f5cc19473..19648edba 100644
--- a/client/internal/networkmonitor/monitor_generic.go
+++ b/client/internal/networkmonitor/monitor_generic.go
@@ -59,7 +59,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
- err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
+ err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go
index 308b2aa45..cd48c269d 100644
--- a/client/internal/networkmonitor/monitor_windows.go
+++ b/client/internal/networkmonitor/monitor_windows.go
@@ -3,252 +3,73 @@ package networkmonitor
import (
"context"
"fmt"
- "net"
- "net/netip"
"strings"
- "time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
-const (
- unreachable = 0
- incomplete = 1
- probe = 2
- delay = 3
- stale = 4
- reachable = 5
- permanent = 6
- tbd = 7
-)
-
-const interval = 10 * time.Second
-
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error {
- var neighborv4, neighborv6 *systemops.Neighbor
- {
- initialNeighbors, err := getNeighbors()
- if err != nil {
- return fmt.Errorf("get neighbors: %w", err)
- }
-
- neighborv4 = assignNeighbor(nexthopv4, initialNeighbors)
- neighborv6 = assignNeighbor(nexthopv6, initialNeighbors)
+ routeMonitor, err := systemops.NewRouteMonitor(ctx)
+ if err != nil {
+ return fmt.Errorf("failed to create route monitor: %w", err)
}
- log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6)
-
- ticker := time.NewTicker(interval)
- defer ticker.Stop()
+ defer func() {
+ if err := routeMonitor.Stop(); err != nil {
+ log.Errorf("Network monitor: failed to stop route monitor: %v", err)
+ }
+ }()
for {
select {
case <-ctx.Done():
return ErrStopped
- case <-ticker.C:
- if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) {
- go callback()
- return nil
+ case route := <-routeMonitor.RouteUpdates():
+ if route.Destination.Bits() != 0 {
+ continue
+ }
+
+ if routeChanged(route, nexthopv4, nexthopv6, callback) {
+ break
}
}
}
}
-func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor {
- if n, ok := initialNeighbors[nexthop.IP]; ok &&
- n.State != unreachable &&
- n.State != incomplete &&
- n.State != tbd {
- return &n
- }
- return nil
-}
-
-func changed(
- nexthopv4 systemops.Nexthop,
- neighborv4 *systemops.Neighbor,
- nexthopv6 systemops.Nexthop,
- neighborv6 *systemops.Neighbor,
-) bool {
- neighbors, err := getNeighbors()
- if err != nil {
- log.Errorf("network monitor: error fetching current neighbors: %v", err)
- return false
- }
- if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) {
- return true
- }
-
- routes, err := getRoutes()
- if err != nil {
- log.Errorf("network monitor: error fetching current routes: %v", err)
- return false
- }
-
- if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) {
- return true
- }
-
- return false
-}
-
-// routeChanged checks if the default routes still point to our nexthop/interface
-func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool {
- if !nexthop.IP.IsValid() {
- return false
- }
-
- if isSoftInterface(nexthop.Intf.Name) {
- log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name)
- return false
- }
-
- unspec := getUnspecifiedPrefix(nexthop.IP)
- defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec)
-
- log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n"))
-
- if !foundMatchingRoute {
- logRouteChange(nexthop.IP, intf)
- return true
- }
-
- return false
-}
-
-func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix {
- if ip.Is6() {
- return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
- }
- return netip.PrefixFrom(netip.IPv4Unspecified(), 0)
-}
-
-func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) {
- var defaultRoutes []string
- foundMatchingRoute := false
-
- for _, r := range routes {
- if r.Destination == unspec {
- routeInfo := formatRouteInfo(r)
- defaultRoutes = append(defaultRoutes, routeInfo)
-
- if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 {
- foundMatchingRoute = true
- log.Debugf("network monitor: found matching default route: %s", routeInfo)
- }
+func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool {
+ intf := ""
+ if route.Interface != nil {
+ intf = route.Interface.Name
+ if isSoftInterface(intf) {
+ log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf)
+ return false
}
}
- return defaultRoutes, foundMatchingRoute
-}
-
-func formatRouteInfo(r systemops.Route) string {
- newIntf := ""
- if r.Interface != nil {
- newIntf = r.Interface.Name
- }
- return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf)
-}
-
-func logRouteChange(ip netip.Addr, intf *net.Interface) {
- oldIntf := ""
- if intf != nil {
- oldIntf = intf.Name
- }
- log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf)
-}
-
-func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool {
- if neighbor == nil {
- return false
- }
-
- // TODO: consider non-local nexthops, e.g. on point-to-point interfaces
- if n, ok := neighbors[nexthop.IP]; ok {
- if n.State == unreachable || n.State == incomplete {
- log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State))
- return true
- } else if n.InterfaceIndex != neighbor.InterfaceIndex {
- log.Infof(
- "network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s",
- neighbor.IPAddress,
- neighbor.LinkLayerAddress,
- neighbor.InterfaceAlias,
- neighbor.InterfaceIndex,
- n.InterfaceAlias,
- n.InterfaceIndex,
- stateFromInt(n.State),
- )
+ switch route.Type {
+ case systemops.RouteModified:
+ // TODO: get routing table to figure out if our route is affected for modified routes
+ log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
+ go callback()
+ return true
+ case systemops.RouteAdded:
+ if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
+ log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
+ go callback()
+ return true
+ }
+ case systemops.RouteDeleted:
+ if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
+ log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
+ go callback()
return true
}
- } else {
- log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress)
- return true
}
return false
}
-func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) {
- entries, err := systemops.GetNeighbors()
- if err != nil {
- return nil, fmt.Errorf("get neighbors: %w", err)
- }
-
- neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries))
- for _, entry := range entries {
- neighbours[entry.IPAddress] = entry
- }
-
- return neighbours, nil
-}
-
-func getRoutes() ([]systemops.Route, error) {
- entries, err := systemops.GetRoutes()
- if err != nil {
- return nil, fmt.Errorf("get routes: %w", err)
- }
-
- return entries, nil
-}
-
-func stateFromInt(state uint8) string {
- switch state {
- case unreachable:
- return "unreachable"
- case incomplete:
- return "incomplete"
- case probe:
- return "probe"
- case delay:
- return "delay"
- case stale:
- return "stale"
- case reachable:
- return "reachable"
- case permanent:
- return "permanent"
- case tbd:
- return "tbd"
- default:
- return "unknown"
- }
-}
-
-func compareIntf(a, b *net.Interface) int {
- switch {
- case a == nil && b == nil:
- return 0
- case a == nil:
- return -1
- case b == nil:
- return 1
- default:
- return a.Index - b.Index
- }
-}
-
func isSoftInterface(name string) bool {
return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo")
}
diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go
index 0d8fd932c..ad84bd700 100644
--- a/client/internal/peer/conn.go
+++ b/client/internal/peer/conn.go
@@ -2,570 +2,277 @@ package peer
import (
"context"
- "fmt"
+ "math/rand"
"net"
+ "os"
"runtime"
"strings"
"sync"
"time"
+ "github.com/cenkalti/backoff/v4"
"github.com/pion/ice/v3"
- "github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
- "github.com/netbirdio/netbird/iface"
- "github.com/netbirdio/netbird/iface/bind"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
- sProto "github.com/netbirdio/netbird/signal/proto"
nbnet "github.com/netbirdio/netbird/util/net"
- "github.com/netbirdio/netbird/version"
)
-const (
- iceKeepAliveDefault = 4 * time.Second
- iceDisconnectedTimeoutDefault = 6 * time.Second
- // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
- iceRelayAcceptanceMinWaitDefault = 2 * time.Second
+type ConnPriority int
+const (
defaultWgKeepAlive = 25 * time.Second
+
+ connPriorityRelay ConnPriority = 1
+ connPriorityICETurn ConnPriority = 1
+ connPriorityICEP2P ConnPriority = 2
)
type WgConfig struct {
WgListenPort int
RemoteKey string
- WgInterface *iface.WGIface
+ WgInterface iface.IWGIface
AllowedIps string
PreSharedKey *wgtypes.Key
}
// ConnConfig is a peer Connection configuration
type ConnConfig struct {
-
// Key is a public key of a remote peer
Key string
// LocalKey is a public key of a local peer
LocalKey string
- // StunTurn is a list of STUN and TURN URLs
- StunTurn []*stun.URI
-
- // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
- // (e.g. if eth0 is in the list, host candidate of this interface won't be used)
- InterfaceBlackList []string
- DisableIPv6Discovery bool
-
Timeout time.Duration
WgConfig WgConfig
- UDPMux ice.UDPMux
- UDPMuxSrflx ice.UniversalUDPMux
-
LocalWgPort int
- NATExternalIPs []string
-
// RosenpassPubKey is this peer's Rosenpass public key
RosenpassPubKey []byte
// RosenpassPubKey is this peer's RosenpassAddr server address (IP:port)
RosenpassAddr string
+
+ // ICEConfig ICE protocol configuration
+ ICEConfig ICEConfig
}
-// OfferAnswer represents a session establishment offer or answer
-type OfferAnswer struct {
- IceCredentials IceCredentials
- // WgListenPort is a remote WireGuard listen port.
- // This field is used when establishing a direct WireGuard connection without any proxy.
- // We can set the remote peer's endpoint with this port.
- WgListenPort int
+type WorkerCallbacks struct {
+ OnRelayReadyCallback func(info RelayConnInfo)
+ OnRelayStatusChanged func(ConnStatus)
- // Version of NetBird Agent
- Version string
- // RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
- // This value is the local Rosenpass server public key when sending the message
- RosenpassPubKey []byte
- // RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
- // This value is the local Rosenpass server address when sending the message
- RosenpassAddr string
-}
-
-// IceCredentials ICE protocol credentials struct
-type IceCredentials struct {
- UFrag string
- Pwd string
+ OnICEConnReadyCallback func(ConnPriority, ICEConnInfo)
+ OnICEStatusChanged func(ConnStatus)
}
type Conn struct {
- config ConnConfig
- mu sync.Mutex
-
- // signalCandidate is a handler function to signal remote peer about local connection candidate
- signalCandidate func(candidate ice.Candidate) error
- // signalOffer is a handler function to signal remote peer our connection offer (credentials)
- signalOffer func(OfferAnswer) error
- signalAnswer func(OfferAnswer) error
- sendSignalMessage func(message *sProto.Message) error
- onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
- onDisconnected func(remotePeer string, wgIP string)
-
- // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
- remoteOffersCh chan OfferAnswer
- // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
- remoteAnswerCh chan OfferAnswer
- closeCh chan struct{}
- ctx context.Context
- notifyDisconnected context.CancelFunc
-
- agent *ice.Agent
- status ConnStatus
-
+ log *log.Entry
+ mu sync.Mutex
+ ctx context.Context
+ ctxCancel context.CancelFunc
+ config ConnConfig
statusRecorder *Status
-
wgProxyFactory *wgproxy.Factory
- wgProxy wgproxy.Proxy
+ wgProxyICE wgproxy.Proxy
+ wgProxyRelay wgproxy.Proxy
+ signaler *Signaler
+ relayManager *relayClient.Manager
+ allowedIPsIP string
+ handshaker *Handshaker
- adapter iface.TunAdapter
- iFaceDiscover stdnet.ExternalIFaceDiscover
- sentExtraSrflx bool
+ onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)
+ onDisconnected func(remotePeer string, wgIP string)
- connID nbnet.ConnectionID
+ statusRelay *AtomicConnStatus
+ statusICE *AtomicConnStatus
+ currentConnPriority ConnPriority
+ opened bool // this flag is used to prevent close in case of not opened connection
+
+ workerICE *WorkerICE
+ workerRelay *WorkerRelay
+
+ connIDRelay nbnet.ConnectionID
+ connIDICE nbnet.ConnectionID
beforeAddPeerHooks []nbnet.AddHookFunc
afterRemovePeerHooks []nbnet.RemoveHookFunc
-}
-// GetConf returns the connection config
-func (conn *Conn) GetConf() ConnConfig {
- return conn.config
-}
+ endpointRelay *net.UDPAddr
-// WgConfig returns the WireGuard config
-func (conn *Conn) WgConfig() WgConfig {
- return conn.config.WgConfig
-}
-
-// UpdateStunTurn update the turn and stun addresses
-func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) {
- conn.config.StunTurn = turnStun
+ // for reconnection operations
+ iCEDisconnected chan bool
+ relayDisconnected chan bool
}
// NewConn creates a new not opened Conn to the remote peer.
// To establish a connection run Conn.Open
-func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) {
- return &Conn{
- config: config,
- mu: sync.Mutex{},
- status: StatusDisconnected,
- closeCh: make(chan struct{}),
- remoteOffersCh: make(chan OfferAnswer),
- remoteAnswerCh: make(chan OfferAnswer),
- statusRecorder: statusRecorder,
- wgProxyFactory: wgProxyFactory,
- adapter: adapter,
- iFaceDiscover: iFaceDiscover,
- }, nil
+func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) {
+ _, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps)
+ if err != nil {
+ log.Errorf("failed to parse allowedIPS: %v", err)
+ return nil, err
+ }
+
+ ctx, ctxCancel := context.WithCancel(engineCtx)
+ connLog := log.WithField("peer", config.Key)
+
+ var conn = &Conn{
+ log: connLog,
+ ctx: ctx,
+ ctxCancel: ctxCancel,
+ config: config,
+ statusRecorder: statusRecorder,
+ wgProxyFactory: wgProxyFactory,
+ signaler: signaler,
+ relayManager: relayManager,
+ allowedIPsIP: allowedIPsIP.String(),
+ statusRelay: NewAtomicConnStatus(),
+ statusICE: NewAtomicConnStatus(),
+ iCEDisconnected: make(chan bool, 1),
+ relayDisconnected: make(chan bool, 1),
+ }
+
+ rFns := WorkerRelayCallbacks{
+ OnConnReady: conn.relayConnectionIsReady,
+ OnDisconnected: conn.onWorkerRelayStateDisconnected,
+ }
+
+ wFns := WorkerICECallbacks{
+ OnConnReady: conn.iCEConnectionIsReady,
+ OnStatusChanged: conn.onWorkerICEStateDisconnected,
+ }
+
+ conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns)
+
+ relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
+ conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns)
+ if err != nil {
+ return nil, err
+ }
+
+ conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay)
+
+ conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer)
+ if os.Getenv("NB_FORCE_RELAY") != "true" {
+ conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer)
+ }
+
+ go conn.handshaker.Listen()
+
+ return conn, nil
}
-func (conn *Conn) reCreateAgent() error {
+// Open opens connection to the remote peer
+// It will try to establish a connection using ICE and relay in parallel. The higher priority connection type will
+// be used.
+func (conn *Conn) Open() {
+ conn.log.Debugf("open connection to peer")
conn.mu.Lock()
defer conn.mu.Unlock()
-
- failedTimeout := 6 * time.Second
-
- var err error
- transportNet, err := conn.newStdNet()
- if err != nil {
- log.Errorf("failed to create pion's stdnet: %s", err)
- }
-
- iceKeepAlive := iceKeepAlive()
- iceDisconnectedTimeout := iceDisconnectedTimeout()
- iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
-
- agentConfig := &ice.AgentConfig{
- MulticastDNSMode: ice.MulticastDNSModeDisabled,
- NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
- Urls: conn.config.StunTurn,
- CandidateTypes: conn.candidateTypes(),
- FailedTimeout: &failedTimeout,
- InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList),
- UDPMux: conn.config.UDPMux,
- UDPMuxSrflx: conn.config.UDPMuxSrflx,
- NAT1To1IPs: conn.config.NATExternalIPs,
- Net: transportNet,
- DisconnectedTimeout: &iceDisconnectedTimeout,
- KeepaliveInterval: &iceKeepAlive,
- RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
- }
-
- if conn.config.DisableIPv6Discovery {
- agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
- }
-
- conn.agent, err = ice.NewAgent(agentConfig)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnCandidate(conn.onICECandidate)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnConnectionStateChange(conn.onICEConnectionStateChange)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnSelectedCandidatePairChange(conn.onICESelectedCandidatePair)
- if err != nil {
- return err
- }
-
- err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
- err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency())
- if err != nil {
- log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err)
- return
- }
- })
- if err != nil {
- return fmt.Errorf("failed setting binding response callback: %w", err)
- }
-
- return nil
-}
-
-func (conn *Conn) candidateTypes() []ice.CandidateType {
- if hasICEForceRelayConn() {
- return []ice.CandidateType{ice.CandidateTypeRelay}
- }
- // TODO: remove this once we have refactored userspace proxy into the bind package
- if runtime.GOOS == "ios" {
- return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
- }
- return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
-}
-
-// Open opens connection to the remote peer starting ICE candidate gathering process.
-// Blocks until connection has been closed or connection timeout.
-// ConnStatus will be set accordingly
-func (conn *Conn) Open(ctx context.Context) error {
- log.Debugf("trying to connect to peer %s", conn.config.Key)
+ conn.opened = true
peerState := State{
PubKey: conn.config.Key,
IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0],
ConnStatusUpdate: time.Now(),
- ConnStatus: conn.status,
+ ConnStatus: StatusDisconnected,
Mux: new(sync.RWMutex),
}
err := conn.statusRecorder.UpdatePeerState(peerState)
if err != nil {
- log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
+ conn.log.Warnf("error while updating the state err: %v", err)
}
- defer func() {
- err := conn.cleanup()
- if err != nil {
- log.Warnf("error while cleaning up peer connection %s: %v", conn.config.Key, err)
- return
- }
- }()
+ go conn.startHandshakeAndReconnect()
+}
- err = conn.reCreateAgent()
+func (conn *Conn) startHandshakeAndReconnect() {
+ conn.waitInitialRandomSleepTime()
+
+ err := conn.handshaker.sendOffer()
if err != nil {
- return err
+ conn.log.Errorf("failed to send initial offer: %v", err)
}
- err = conn.sendOffer()
- if err != nil {
- return err
- }
-
- log.Debugf("connection offer sent to peer %s, waiting for the confirmation", conn.config.Key)
-
- // Only continue once we got a connection confirmation from the remote peer.
- // The connection timeout could have happened before a confirmation received from the remote.
- // The connection could have also been closed externally (e.g. when we received an update from the management that peer shouldn't be connected)
- var remoteOfferAnswer OfferAnswer
- select {
- case remoteOfferAnswer = <-conn.remoteOffersCh:
- // received confirmation from the remote peer -> ready to proceed
- err = conn.sendAnswer()
- if err != nil {
- return err
- }
- case remoteOfferAnswer = <-conn.remoteAnswerCh:
- case <-time.After(conn.config.Timeout):
- return NewConnectionTimeoutError(conn.config.Key, conn.config.Timeout)
- case <-conn.closeCh:
- // closed externally
- return NewConnectionClosedError(conn.config.Key)
- }
-
- log.Debugf("received connection confirmation from peer %s running version %s and with remote WireGuard listen port %d",
- conn.config.Key, remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
-
- // at this point we received offer/answer and we are ready to gather candidates
- conn.mu.Lock()
- conn.status = StatusConnecting
- conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx)
- defer conn.notifyDisconnected()
- conn.mu.Unlock()
-
- peerState = State{
- PubKey: conn.config.Key,
- ConnStatus: conn.status,
- ConnStatusUpdate: time.Now(),
- Mux: new(sync.RWMutex),
- }
- err = conn.statusRecorder.UpdatePeerState(peerState)
- if err != nil {
- log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err)
- }
-
- err = conn.agent.GatherCandidates()
- if err != nil {
- return fmt.Errorf("gather candidates: %v", err)
- }
-
- // will block until connection succeeded
- // but it won't release if ICE Agent went into Disconnected or Failed state,
- // so we have to cancel it with the provided context once agent detected a broken connection
- isControlling := conn.config.LocalKey > conn.config.Key
- var remoteConn *ice.Conn
- if isControlling {
- remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ if conn.workerRelay.IsController() {
+ conn.reconnectLoopWithRetry()
} else {
- remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
- }
- if err != nil {
- return err
- }
-
- // dynamically set remote WireGuard port if other side specified a different one from the default one
- remoteWgPort := iface.DefaultWgPort
- if remoteOfferAnswer.WgListenPort != 0 {
- remoteWgPort = remoteOfferAnswer.WgListenPort
- }
-
- // the ice connection has been established successfully so we are ready to start the proxy
- remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey,
- remoteOfferAnswer.RosenpassAddr)
- if err != nil {
- return err
- }
-
- log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String())
-
- // wait until connection disconnected or has been closed externally (upper layer, e.g. engine)
- select {
- case <-conn.closeCh:
- // closed externally
- return NewConnectionClosedError(conn.config.Key)
- case <-conn.ctx.Done():
- // disconnected from the remote peer
- return NewConnectionDisconnectedError(conn.config.Key)
+ conn.reconnectLoopForOnDisconnectedEvent()
}
}
-func isRelayCandidate(candidate ice.Candidate) bool {
- return candidate.Type() == ice.CandidateTypeRelay
+// Close closes this peer Conn, cancels its context and releases its resources (workers, proxies, WireGuard peer)
+func (conn *Conn) Close() {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ conn.log.Infof("close peer connection")
+ conn.ctxCancel()
+
+ if !conn.opened {
+ conn.log.Debugf("ignore close connection to peer")
+ return
+ }
+
+ conn.workerRelay.DisableWgWatcher()
+ conn.workerRelay.CloseConn()
+ conn.workerICE.Close()
+
+ if conn.wgProxyRelay != nil {
+ err := conn.wgProxyRelay.CloseConn()
+ if err != nil {
+ conn.log.Errorf("failed to close wg proxy for relay: %v", err)
+ }
+ conn.wgProxyRelay = nil
+ }
+
+ if conn.wgProxyICE != nil {
+ err := conn.wgProxyICE.CloseConn()
+ if err != nil {
+ conn.log.Errorf("failed to close wg proxy for ice: %v", err)
+ }
+ conn.wgProxyICE = nil
+ }
+
+ err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
+ if err != nil {
+ conn.log.Errorf("failed to remove wg endpoint: %v", err)
+ }
+
+ conn.freeUpConnID()
+
+ if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
+ conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
+ }
+
+ conn.setStatusToDisconnected()
+}
+
+// OnRemoteAnswer handles an answer from the remote peer and returns true if the message was accepted, false otherwise
+// doesn't block, discards the message if connection wasn't ready
+func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
+ conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay)
+ return conn.handshaker.OnRemoteAnswer(answer)
+}
+
+// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
+func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
+ conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
}
-
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
}
-// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
-func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- pair, err := conn.agent.GetSelectedCandidatePair()
- if err != nil {
- return nil, err
- }
-
- var endpoint net.Addr
- if isRelayCandidate(pair.Local) {
- log.Debugf("setup relay connection")
- conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx)
- endpoint, err = conn.wgProxy.AddTurnConn(remoteConn)
- if err != nil {
- return nil, err
- }
- } else {
- // To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port
- go conn.punchRemoteWGPort(pair, remoteWgPort)
- endpoint = remoteConn.RemoteAddr()
- }
-
- endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
- log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
-
- conn.connID = nbnet.GenerateConnID()
- for _, hook := range conn.beforeAddPeerHooks {
- if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
- log.Errorf("Before add peer hook failed: %v", err)
- }
- }
-
- err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
- if err != nil {
- if conn.wgProxy != nil {
- if err := conn.wgProxy.CloseConn(); err != nil {
- log.Warnf("Failed to close turn connection: %v", err)
- }
- }
- return nil, fmt.Errorf("update peer: %w", err)
- }
-
- conn.status = StatusConnected
- rosenpassEnabled := false
- if remoteRosenpassPubKey != nil {
- rosenpassEnabled = true
- }
-
- peerState := State{
- PubKey: conn.config.Key,
- ConnStatus: conn.status,
- ConnStatusUpdate: time.Now(),
- LocalIceCandidateType: pair.Local.Type().String(),
- RemoteIceCandidateType: pair.Remote.Type().String(),
- LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
- RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
- Direct: !isRelayCandidate(pair.Local),
- RosenpassEnabled: rosenpassEnabled,
- Mux: new(sync.RWMutex),
- }
- if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
- peerState.Relayed = true
- }
-
- err = conn.statusRecorder.UpdatePeerState(peerState)
- if err != nil {
- log.Warnf("unable to save peer's state, got error: %v", err)
- }
-
- _, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps)
- if err != nil {
- return nil, err
- }
-
- if runtime.GOOS == "ios" {
- runtime.GC()
- }
-
- if conn.onConnected != nil {
- conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr)
- }
-
- return endpoint, nil
-}
-
-func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
- // wait local endpoint configuration
- time.Sleep(time.Second)
- addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
- if err != nil {
- log.Warnf("got an error while resolving the udp address, err: %s", err)
- return
- }
-
- mux, ok := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
- if !ok {
- log.Warn("invalid udp mux conversion")
- return
- }
- _, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
- if err != nil {
- log.Warnf("got an error while sending the punch packet, err: %s", err)
- }
-}
-
-// cleanup closes all open resources and sets status to StatusDisconnected
-func (conn *Conn) cleanup() error {
- log.Debugf("trying to cleanup %s", conn.config.Key)
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- conn.sentExtraSrflx = false
-
- var err1, err2, err3 error
- if conn.agent != nil {
- err1 = conn.agent.Close()
- if err1 == nil {
- conn.agent = nil
- }
- }
-
- if conn.wgProxy != nil {
- err2 = conn.wgProxy.CloseConn()
- conn.wgProxy = nil
- }
-
- // todo: is it problem if we try to remove a peer what is never existed?
- err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
-
- if conn.connID != "" {
- for _, hook := range conn.afterRemovePeerHooks {
- if err := hook(conn.connID); err != nil {
- log.Errorf("After remove peer hook failed: %v", err)
- }
- }
- }
- conn.connID = ""
-
- if conn.notifyDisconnected != nil {
- conn.notifyDisconnected()
- conn.notifyDisconnected = nil
- }
-
- if conn.status == StatusConnected && conn.onDisconnected != nil {
- conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps)
- }
-
- conn.status = StatusDisconnected
-
- peerState := State{
- PubKey: conn.config.Key,
- ConnStatus: conn.status,
- ConnStatusUpdate: time.Now(),
- Mux: new(sync.RWMutex),
- }
- err := conn.statusRecorder.UpdatePeerState(peerState)
- if err != nil {
- // pretty common error because by that time Engine can already remove the peer and status won't be available.
- // todo rethink status updates
- log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err)
- }
- if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil {
- log.Debugf("failed to reset wireguard stats for peer %s: %s", conn.config.Key, err)
- }
-
- log.Debugf("cleaned up connection to peer %s", conn.config.Key)
- if err1 != nil {
- return err1
- }
- if err2 != nil {
- return err2
- }
- return err3
-}
-
-// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer
-func (conn *Conn) SetSignalOffer(handler func(offer OfferAnswer) error) {
- conn.signalOffer = handler
-}
-
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
conn.onConnected = handler
@@ -576,218 +283,521 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string)
conn.onDisconnected = handler
}
-// SetSignalAnswer sets a handler function to be triggered by Conn when a new connection answer has to be signalled to the remote peer
-func (conn *Conn) SetSignalAnswer(handler func(answer OfferAnswer) error) {
- conn.signalAnswer = handler
+func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
+ conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay)
+ return conn.handshaker.OnRemoteOffer(offer)
}
-// SetSignalCandidate sets a handler function to be triggered by Conn when a new ICE local connection candidate has to be signalled to the remote peer
-func (conn *Conn) SetSignalCandidate(handler func(candidate ice.Candidate) error) {
- conn.signalCandidate = handler
-}
-
-// SetSendSignalMessage sets a handler function to be triggered by Conn when there is new message to send via signal
-func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) error) {
- conn.sendSignalMessage = handler
-}
-
-// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
-// and then signals them to the remote peer
-func (conn *Conn) onICECandidate(candidate ice.Candidate) {
- // nil means candidate gathering has been ended
- if candidate == nil {
- return
- }
-
- // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
- log.Debugf("discovered local candidate %s", candidate.String())
- go func() {
- err := conn.signalCandidate(candidate)
- if err != nil {
- log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err)
- }
- }()
-
- if !conn.shouldSendExtraSrflxCandidate(candidate) {
- return
- }
-
- // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
- // this is useful when network has an existing port forwarding rule for the wireguard port and this peer
- extraSrflx, err := extraSrflxCandidate(candidate)
- if err != nil {
- log.Errorf("failed creating extra server reflexive candidate %s", err)
- return
- }
- conn.sentExtraSrflx = true
-
- go func() {
- err = conn.signalCandidate(extraSrflx)
- if err != nil {
- log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err)
- }
- }()
-}
-
-func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
- log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
- conn.config.Key)
-}
-
-// onICEConnectionStateChange registers callback of an ICE Agent to track connection state
-func (conn *Conn) onICEConnectionStateChange(state ice.ConnectionState) {
- log.Debugf("peer %s ICE ConnectionState has changed to %s", conn.config.Key, state.String())
- if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
- conn.notifyDisconnected()
- }
-}
-
-func (conn *Conn) sendAnswer() error {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- localUFrag, localPwd, err := conn.agent.GetLocalUserCredentials()
- if err != nil {
- return err
- }
-
- log.Debugf("sending answer to %s", conn.config.Key)
- err = conn.signalAnswer(OfferAnswer{
- IceCredentials: IceCredentials{localUFrag, localPwd},
- WgListenPort: conn.config.LocalWgPort,
- Version: version.NetbirdVersion(),
- RosenpassPubKey: conn.config.RosenpassPubKey,
- RosenpassAddr: conn.config.RosenpassAddr,
- })
- if err != nil {
- return err
- }
-
- return nil
-}
-
-// sendOffer prepares local user credentials and signals them to the remote peer
-func (conn *Conn) sendOffer() error {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- localUFrag, localPwd, err := conn.agent.GetLocalUserCredentials()
- if err != nil {
- return err
- }
- err = conn.signalOffer(OfferAnswer{
- IceCredentials: IceCredentials{localUFrag, localPwd},
- WgListenPort: conn.config.LocalWgPort,
- Version: version.NetbirdVersion(),
- RosenpassPubKey: conn.config.RosenpassPubKey,
- RosenpassAddr: conn.config.RosenpassAddr,
- })
- if err != nil {
- return err
- }
- return nil
-}
-
-// Close closes this peer Conn issuing a close event to the Conn closeCh
-func (conn *Conn) Close() error {
- conn.mu.Lock()
- defer conn.mu.Unlock()
- select {
- case conn.closeCh <- struct{}{}:
- return nil
- default:
- // probably could happen when peer has been added and removed right after not even starting to connect
- // todo further investigate
- // this really happens due to unordered messages coming from management
- // more importantly it causes inconsistency -> 2 Conn objects for the same peer
- // e.g. this flow:
- // update from management has peers: [1,2,3,4]
- // engine creates a Conn for peers: [1,2,3,4] and schedules Open in ~1sec
- // before conn.Open() another update from management arrives with peers: [1,2,3]
- // engine removes peer 4 and calls conn.Close() which does nothing (this default clause)
- // before conn.Open() another update from management arrives with peers: [1,2,3,4,5]
- // engine adds a new Conn for 4 and 5
- // therefore peer 4 has 2 Conn objects
- log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key)
- return NewConnectionAlreadyClosed(conn.config.Key)
- }
+// WgConfig returns the WireGuard config
+func (conn *Conn) WgConfig() WgConfig {
+ return conn.config.WgConfig
}
// Status returns current status of the Conn
func (conn *Conn) Status() ConnStatus {
conn.mu.Lock()
defer conn.mu.Unlock()
- return conn.status
-}
-
-// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
-// doesn't block, discards the message if connection wasn't ready
-func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool {
- log.Debugf("OnRemoteOffer from peer %s on status %s", conn.config.Key, conn.status.String())
-
- select {
- case conn.remoteOffersCh <- offer:
- return true
- default:
- log.Debugf("OnRemoteOffer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String())
- // connection might not be ready yet to receive so we ignore the message
- return false
- }
-}
-
-// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
-// doesn't block, discards the message if connection wasn't ready
-func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool {
- log.Debugf("OnRemoteAnswer from peer %s on status %s", conn.config.Key, conn.status.String())
-
- select {
- case conn.remoteAnswerCh <- answer:
- return true
- default:
- // connection might not be ready yet to receive so we ignore the message
- log.Debugf("OnRemoteAnswer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String())
- return false
- }
-}
-
-// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
-func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
- log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String())
- go func() {
- conn.mu.Lock()
- defer conn.mu.Unlock()
-
- if conn.agent == nil {
- return
- }
-
- err := conn.agent.AddRemoteCandidate(candidate)
- if err != nil {
- log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
- return
- }
- }()
+ return conn.evalStatus()
}
func (conn *Conn) GetKey() string {
return conn.config.Key
}
-func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
- if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
- return true
+func (conn *Conn) reconnectLoopWithRetry() {
+	// Give the peer a chance to establish the initial connection.
+	// This lets us avoid sending unnecessary offers.
+ select {
+ case <-conn.ctx.Done():
+ case <-time.After(3 * time.Second):
+ }
+
+ ticker := conn.prepareExponentTicker()
+ defer ticker.Stop()
+ time.Sleep(1 * time.Second)
+ for {
+ select {
+ case t := <-ticker.C:
+ if t.IsZero() {
+				// if the ticker has been canceled by the context, avoid busy-looping
+ return
+ }
+
+ if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
+ if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected {
+ conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE)
+ }
+ } else {
+ if conn.statusICE.Get() == StatusDisconnected {
+ conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE)
+ }
+ }
+
+			// check whether the peer connection is already established via relay or ICE
+ if conn.isConnected() {
+ continue
+ }
+
+ err := conn.handshaker.sendOffer()
+ if err != nil {
+ conn.log.Errorf("failed to do handshake: %v", err)
+ }
+ case changed := <-conn.relayDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("Relay state changed, reset reconnect timer")
+ ticker.Stop()
+ ticker = conn.prepareExponentTicker()
+ case changed := <-conn.iCEDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("ICE state changed, reset reconnect timer")
+ ticker.Stop()
+ ticker = conn.prepareExponentTicker()
+ case <-conn.ctx.Done():
+ conn.log.Debugf("context is done, stop reconnect loop")
+ return
+ }
}
- return false
}
-func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
- relatedAdd := candidate.RelatedAddress()
- return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
- Network: candidate.NetworkType().String(),
- Address: candidate.Address(),
- Port: relatedAdd.Port,
- Component: candidate.Component(),
- RelAddr: relatedAdd.Address,
- RelPort: relatedAdd.Port,
- })
+func (conn *Conn) prepareExponentTicker() *backoff.Ticker {
+ bo := backoff.WithContext(&backoff.ExponentialBackOff{
+ InitialInterval: 800 * time.Millisecond,
+ RandomizationFactor: 0.01,
+ Multiplier: 2,
+ MaxInterval: conn.config.Timeout,
+ MaxElapsedTime: 0,
+ Stop: backoff.Stop,
+ Clock: backoff.SystemClock,
+ }, conn.ctx)
+
+ ticker := backoff.NewTicker(bo)
+	<-ticker.C // consume the initial tick that fires immediately after the ticker is created
+
+ return ticker
+}
+
+// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and should reconnect to the peer
+// when the connection is lost. It attempts to re-establish a connection only if one was established before.
+// It tracks the ICE and relay connection statuses separately: the fact that a lower-priority connection was
+// re-established does not mean we switch to it. We always prefer the higher-priority connection.
+func (conn *Conn) reconnectLoopForOnDisconnectedEvent() {
+ for {
+ select {
+ case changed := <-conn.relayDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("Relay state changed, try to send new offer")
+ case changed := <-conn.iCEDisconnected:
+ if !changed {
+ continue
+ }
+ conn.log.Debugf("ICE state changed, try to send new offer")
+ case <-conn.ctx.Done():
+ conn.log.Debugf("context is done, stop reconnect loop")
+ return
+ }
+
+ err := conn.handshaker.SendOffer()
+ if err != nil {
+ conn.log.Errorf("failed to do handshake: %v", err)
+ }
+ }
+}
+
+// iCEConnectionIsReady starts proxying traffic from/to the local WireGuard interface and sets the connection status to StatusConnected
+func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ conn.log.Debugf("ICE connection is ready")
+
+ conn.statusICE.Set(StatusConnected)
+
+ defer conn.updateIceState(iceConnInfo)
+
+ if conn.currentConnPriority > priority {
+ return
+ }
+
+ conn.log.Infof("set ICE to active connection")
+
+ endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo)
+ if err != nil {
+ return
+ }
+
+ endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
+ conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP)
+
+ conn.connIDICE = nbnet.GenerateConnID()
+ for _, hook := range conn.beforeAddPeerHooks {
+ if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil {
+ conn.log.Errorf("Before add peer hook failed: %v", err)
+ }
+ }
+
+ conn.workerRelay.DisableWgWatcher()
+
+ err = conn.configureWGEndpoint(endpointUdpAddr)
+ if err != nil {
+ if wgProxy != nil {
+ if err := wgProxy.CloseConn(); err != nil {
+ conn.log.Warnf("Failed to close turn connection: %v", err)
+ }
+ }
+ conn.log.Warnf("Failed to update wg peer configuration: %v", err)
+ return
+ }
+ wgConfigWorkaround()
+
+ if conn.wgProxyICE != nil {
+ if err := conn.wgProxyICE.CloseConn(); err != nil {
+ conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
+ }
+ }
+ conn.wgProxyICE = wgProxy
+
+ conn.currentConnPriority = priority
+
+ conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr)
+}
+
+// TODO: review whether it makes sense to handle the connecting and disconnected statuses here as well
+func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ conn.log.Tracef("ICE connection state changed to %s", newState)
+
+ // switch back to relay connection
+ if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay {
+ conn.log.Debugf("ICE disconnected, set Relay to active connection")
+ err := conn.configureWGEndpoint(conn.endpointRelay)
+ if err != nil {
+ conn.log.Errorf("failed to switch to relay conn: %v", err)
+ }
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
+ conn.currentConnPriority = connPriorityRelay
+ }
+
+ changed := conn.statusICE.Get() != newState && newState != StatusConnecting
+ conn.statusICE.Set(newState)
+
+ select {
+ case conn.iCEDisconnected <- changed:
+ default:
+ }
+
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatus: conn.evalStatus(),
+ Relayed: conn.isRelayed(),
+ ConnStatusUpdate: time.Now(),
+ }
+
+ err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err)
+ }
+}
+
+func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ if err := rci.relayedConn.Close(); err != nil {
+ log.Warnf("failed to close unnecessary relayed connection: %v", err)
+ }
+ return
+ }
+
+ conn.log.Debugf("Relay connection is ready to use")
+ conn.statusRelay.Set(StatusConnected)
+
+ wgProxy := conn.wgProxyFactory.GetProxy()
+ endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn)
+ if err != nil {
+ conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err)
+ return
+ }
+ conn.log.Infof("created new wgProxy for relay connection: %s", endpoint)
+
+ endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String())
+ conn.endpointRelay = endpointUdpAddr
+ conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
+
+ defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
+
+ if conn.currentConnPriority > connPriorityRelay {
+ if conn.statusICE.Get() == StatusConnected {
+ log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority)
+ return
+ }
+ }
+
+ conn.connIDRelay = nbnet.GenerateConnID()
+ for _, hook := range conn.beforeAddPeerHooks {
+ if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil {
+ conn.log.Errorf("Before add peer hook failed: %v", err)
+ }
+ }
+
+ err = conn.configureWGEndpoint(endpointUdpAddr)
+ if err != nil {
+ if err := wgProxy.CloseConn(); err != nil {
+ conn.log.Warnf("Failed to close relay connection: %v", err)
+ }
+ conn.log.Errorf("Failed to update wg peer configuration: %v", err)
+ return
+ }
+ conn.workerRelay.EnableWgWatcher(conn.ctx)
+ wgConfigWorkaround()
+
+ if conn.wgProxyRelay != nil {
+ if err := conn.wgProxyRelay.CloseConn(); err != nil {
+ conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err)
+ }
+ }
+ conn.wgProxyRelay = wgProxy
+ conn.currentConnPriority = connPriorityRelay
+
+ conn.log.Infof("start to communicate with peer via relay")
+ conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr)
+}
+
+func (conn *Conn) onWorkerRelayStateDisconnected() {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.ctx.Err() != nil {
+ return
+ }
+
+ log.Debugf("relay connection is disconnected")
+
+ if conn.currentConnPriority == connPriorityRelay {
+ log.Debugf("clean up WireGuard config")
+ err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
+ if err != nil {
+ conn.log.Errorf("failed to remove wg endpoint: %v", err)
+ }
+ }
+
+ if conn.wgProxyRelay != nil {
+ conn.endpointRelay = nil
+ _ = conn.wgProxyRelay.CloseConn()
+ conn.wgProxyRelay = nil
+ }
+
+ changed := conn.statusRelay.Get() != StatusDisconnected
+ conn.statusRelay.Set(StatusDisconnected)
+
+ select {
+ case conn.relayDisconnected <- changed:
+ default:
+ }
+
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatus: conn.evalStatus(),
+ Relayed: conn.isRelayed(),
+ ConnStatusUpdate: time.Now(),
+ }
+
+ err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err)
+ }
+}
+
+func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
+ return conn.config.WgConfig.WgInterface.UpdatePeer(
+ conn.config.WgConfig.RemoteKey,
+ conn.config.WgConfig.AllowedIps,
+ defaultWgKeepAlive,
+ addr,
+ conn.config.WgConfig.PreSharedKey,
+ )
+}
+
+func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatusUpdate: time.Now(),
+ ConnStatus: conn.evalStatus(),
+ Relayed: conn.isRelayed(),
+ RelayServerAddress: relayServerAddr,
+ RosenpassEnabled: isRosenpassEnabled(rosenpassPubKey),
+ }
+
+ err := conn.statusRecorder.UpdatePeerRelayedState(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to save peer's Relay state, got error: %v", err)
+ }
+}
+
+func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) {
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatusUpdate: time.Now(),
+ ConnStatus: conn.evalStatus(),
+ Relayed: iceConnInfo.Relayed,
+ LocalIceCandidateType: iceConnInfo.LocalIceCandidateType,
+ RemoteIceCandidateType: iceConnInfo.RemoteIceCandidateType,
+ LocalIceCandidateEndpoint: iceConnInfo.LocalIceCandidateEndpoint,
+ RemoteIceCandidateEndpoint: iceConnInfo.RemoteIceCandidateEndpoint,
+ RosenpassEnabled: isRosenpassEnabled(iceConnInfo.RosenpassPubKey),
+ }
+
+ err := conn.statusRecorder.UpdatePeerICEState(peerState)
+ if err != nil {
+ conn.log.Warnf("unable to save peer's ICE state, got error: %v", err)
+ }
+}
+
+func (conn *Conn) setStatusToDisconnected() {
+ conn.statusRelay.Set(StatusDisconnected)
+ conn.statusICE.Set(StatusDisconnected)
+
+ peerState := State{
+ PubKey: conn.config.Key,
+ ConnStatus: StatusDisconnected,
+ ConnStatusUpdate: time.Now(),
+ Mux: new(sync.RWMutex),
+ }
+ err := conn.statusRecorder.UpdatePeerState(peerState)
+ if err != nil {
+ // pretty common error because by that time Engine can already remove the peer and status won't be available.
+ // todo rethink status updates
+ conn.log.Debugf("error while updating peer's state, err: %v", err)
+ }
+ if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil {
+ conn.log.Debugf("failed to reset wireguard stats for peer: %s", err)
+ }
+}
+
+func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) {
+ if runtime.GOOS == "ios" {
+ runtime.GC()
+ }
+
+ if conn.onConnected != nil {
+ conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr)
+ }
+}
+
+func (conn *Conn) waitInitialRandomSleepTime() {
+ minWait := 100
+ maxWait := 800
+ duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond
+
+ timeout := time.NewTimer(duration)
+ defer timeout.Stop()
+
+ select {
+ case <-conn.ctx.Done():
+ case <-timeout.C:
+ }
+}
+
+func (conn *Conn) isRelayed() bool {
+ if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) {
+ return false
+ }
+
+ if conn.currentConnPriority == connPriorityICEP2P {
+ return false
+ }
+
+ return true
+}
+
+func (conn *Conn) evalStatus() ConnStatus {
+ if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected {
+ return StatusConnected
+ }
+
+ if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting {
+ return StatusConnecting
+ }
+
+ return StatusDisconnected
+}
+
+func (conn *Conn) isConnected() bool {
+ conn.mu.Lock()
+ defer conn.mu.Unlock()
+
+ if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting {
+ return false
+ }
+
+ if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
+ if conn.statusRelay.Get() != StatusConnected {
+ return false
+ }
+ }
+
+ return true
+}
+
+func (conn *Conn) freeUpConnID() {
+ if conn.connIDRelay != "" {
+ for _, hook := range conn.afterRemovePeerHooks {
+ if err := hook(conn.connIDRelay); err != nil {
+ conn.log.Errorf("After remove peer hook failed: %v", err)
+ }
+ }
+ conn.connIDRelay = ""
+ }
+
+ if conn.connIDICE != "" {
+ for _, hook := range conn.afterRemovePeerHooks {
+ if err := hook(conn.connIDICE); err != nil {
+ conn.log.Errorf("After remove peer hook failed: %v", err)
+ }
+ }
+ conn.connIDICE = ""
+ }
+}
+
+func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) {
+ if !iceConnInfo.RelayedOnLocal {
+ return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil
+ }
+ conn.log.Debugf("setup ice turn connection")
+ wgProxy := conn.wgProxyFactory.GetProxy()
+ ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn)
+ if err != nil {
+ conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err)
+ if errClose := wgProxy.CloseConn(); errClose != nil {
+ conn.log.Warnf("failed to close turn proxy connection: %v", errClose)
+ }
+ return nil, nil, err
+ }
+ return ep, wgProxy, nil
+}
+
+func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
+ return remoteRosenpassPubKey != nil
+}
+
+// wgConfigWorkaround is a workaround for an issue with WireGuard configuration updates:
+// when a peer configuration is updated twice in quick succession, the second update can be ignored by WireGuard
+func wgConfigWorkaround() {
+ time.Sleep(100 * time.Millisecond)
}
diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go
index 639117c89..3c747864f 100644
--- a/client/internal/peer/conn_status.go
+++ b/client/internal/peer/conn_status.go
@@ -1,6 +1,10 @@
package peer
-import log "github.com/sirupsen/logrus"
+import (
+ "sync/atomic"
+
+ log "github.com/sirupsen/logrus"
+)
const (
// StatusConnected indicate the peer is in connected state
@@ -12,7 +16,34 @@ const (
)
// ConnStatus describe the status of a peer's connection
-type ConnStatus int
+type ConnStatus int32
+
+// AtomicConnStatus is a thread-safe wrapper for ConnStatus
+type AtomicConnStatus struct {
+ status atomic.Int32
+}
+
+// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status
+func NewAtomicConnStatus() *AtomicConnStatus {
+ acs := &AtomicConnStatus{}
+ acs.Set(StatusDisconnected)
+ return acs
+}
+
+// Get returns the current connection status
+func (acs *AtomicConnStatus) Get() ConnStatus {
+ return ConnStatus(acs.status.Load())
+}
+
+// Set updates the connection status
+func (acs *AtomicConnStatus) Set(status ConnStatus) {
+ acs.status.Store(int32(status))
+}
+
+// String returns the string representation of the current status
+func (acs *AtomicConnStatus) String() string {
+ return acs.Get().String()
+}
func (s ConnStatus) String() string {
switch s {
diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go
index b608a5929..b4926a9d2 100644
--- a/client/internal/peer/conn_test.go
+++ b/client/internal/peer/conn_test.go
@@ -2,25 +2,33 @@ package peer
import (
"context"
+ "os"
"sync"
"testing"
"time"
"github.com/magiconair/properties/assert"
- "github.com/pion/stun/v2"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/util"
)
var connConf = ConnConfig{
- Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
- LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
- StunTurn: []*stun.URI{},
- InterfaceBlackList: nil,
- Timeout: time.Second,
- LocalWgPort: 51820,
+ Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
+ Timeout: time.Second,
+ LocalWgPort: 51820,
+ ICEConfig: ICEConfig{
+ InterfaceBlackList: nil,
+ },
+}
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("trace", "console")
+ code := m.Run()
+ os.Exit(code)
}
func TestNewConn_interfaceFilter(t *testing.T) {
@@ -36,11 +44,11 @@ func TestNewConn_interfaceFilter(t *testing.T) {
}
func TestConn_GetKey(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -51,11 +59,11 @@ func TestConn_GetKey(t *testing.T) {
}
func TestConn_OnRemoteOffer(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -63,7 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
- <-conn.remoteOffersCh
+ <-conn.handshaker.remoteOffersCh
wg.Done()
}()
@@ -88,11 +96,11 @@ func TestConn_OnRemoteOffer(t *testing.T) {
}
func TestConn_OnRemoteAnswer(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
@@ -100,7 +108,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg := sync.WaitGroup{}
wg.Add(2)
go func() {
- <-conn.remoteAnswerCh
+ <-conn.handshaker.remoteAnswerCh
wg.Done()
}()
@@ -124,62 +132,42 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
wg.Wait()
}
func TestConn_Status(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
+ wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort)
defer func() {
_ = wgProxyFactory.Free()
}()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
+ conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil)
if err != nil {
return
}
tables := []struct {
- name string
- status ConnStatus
- want ConnStatus
+ name string
+ statusIce ConnStatus
+ statusRelay ConnStatus
+ want ConnStatus
}{
- {"StatusConnected", StatusConnected, StatusConnected},
- {"StatusDisconnected", StatusDisconnected, StatusDisconnected},
- {"StatusConnecting", StatusConnecting, StatusConnecting},
+ {"StatusConnected", StatusConnected, StatusConnected, StatusConnected},
+ {"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected},
+ {"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting},
+ {"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting},
+ {"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected},
+ {"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting},
+ {"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected},
}
for _, table := range tables {
t.Run(table.name, func(t *testing.T) {
- conn.status = table.status
+ si := NewAtomicConnStatus()
+ si.Set(table.statusIce)
+ conn.statusICE = si
+
+ sr := NewAtomicConnStatus()
+ sr.Set(table.statusRelay)
+ conn.statusRelay = sr
got := conn.Status()
assert.Equal(t, got, table.want, "they should be equal")
})
}
}
-
-func TestConn_Close(t *testing.T) {
- wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort)
- defer func() {
- _ = wgProxyFactory.Free()
- }()
- conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil)
- if err != nil {
- return
- }
-
- wg := sync.WaitGroup{}
- wg.Add(1)
- go func() {
- <-conn.closeCh
- wg.Done()
- }()
-
- go func() {
- for {
- err := conn.Close()
- if err != nil {
- continue
- } else {
- return
- }
- }
- }()
-
- wg.Wait()
-}
diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go
new file mode 100644
index 000000000..545f81966
--- /dev/null
+++ b/client/internal/peer/handshaker.go
@@ -0,0 +1,192 @@
+package peer
+
+import (
+ "context"
+ "errors"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/version"
+)
+
+var (
+ ErrSignalIsNotReady = errors.New("signal is not ready")
+)
+
+// IceCredentials ICE protocol credentials struct
+type IceCredentials struct {
+ UFrag string
+ Pwd string
+}
+
+// OfferAnswer represents a session establishment offer or answer
+type OfferAnswer struct {
+ IceCredentials IceCredentials
+ // WgListenPort is a remote WireGuard listen port.
+ // This field is used when establishing a direct WireGuard connection without any proxy.
+ // We can set the remote peer's endpoint with this port.
+ WgListenPort int
+
+ // Version of NetBird Agent
+ Version string
+ // RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message
+ // This value is the local Rosenpass server public key when sending the message
+ RosenpassPubKey []byte
+ // RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message
+ // This value is the local Rosenpass server address when sending the message
+ RosenpassAddr string
+
+ // relay server address
+ RelaySrvAddress string
+}
+
+type Handshaker struct {
+ mu sync.Mutex
+ ctx context.Context
+ log *log.Entry
+ config ConnConfig
+ signaler *Signaler
+ ice *WorkerICE
+ relay *WorkerRelay
+ onNewOfferListeners []func(*OfferAnswer)
+
+ // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
+ remoteOffersCh chan OfferAnswer
+ // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
+ remoteAnswerCh chan OfferAnswer
+}
+
+func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker {
+ return &Handshaker{
+ ctx: ctx,
+ log: log,
+ config: config,
+ signaler: signaler,
+ ice: ice,
+ relay: relay,
+ remoteOffersCh: make(chan OfferAnswer),
+ remoteAnswerCh: make(chan OfferAnswer),
+ }
+}
+
+func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) {
+ h.onNewOfferListeners = append(h.onNewOfferListeners, offer)
+}
+
+func (h *Handshaker) Listen() {
+ for {
+ h.log.Debugf("wait for remote offer confirmation")
+ remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation()
+ if err != nil {
+ var connectionClosedError *ConnectionClosedError
+ if errors.As(err, &connectionClosedError) {
+ h.log.Tracef("stop handshaker")
+ return
+ }
+ h.log.Errorf("failed to received remote offer confirmation: %s", err)
+ continue
+ }
+
+ h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort)
+ for _, listener := range h.onNewOfferListeners {
+ go listener(remoteOfferAnswer)
+ }
+ }
+}
+
+func (h *Handshaker) SendOffer() error {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ return h.sendOffer()
+}
+
+// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise
+// doesn't block, discards the message if connection wasn't ready
+func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool {
+ select {
+ case h.remoteOffersCh <- offer:
+ return true
+ default:
+ h.log.Debugf("OnRemoteOffer skipping message because is not ready")
+ // connection might not be ready yet to receive so we ignore the message
+ return false
+ }
+}
+
+// OnRemoteAnswer handles an answer from the remote peer and returns true if the message was accepted, false otherwise
+// doesn't block, discards the message if connection wasn't ready
+func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool {
+ select {
+ case h.remoteAnswerCh <- answer:
+ return true
+ default:
+ // connection might not be ready yet to receive so we ignore the message
+ h.log.Debugf("OnRemoteAnswer skipping message because is not ready")
+ return false
+ }
+}
+
+func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) {
+ select {
+ case remoteOfferAnswer := <-h.remoteOffersCh:
+ // received confirmation from the remote peer -> ready to proceed
+ err := h.sendAnswer()
+ if err != nil {
+ return nil, err
+ }
+ return &remoteOfferAnswer, nil
+ case remoteOfferAnswer := <-h.remoteAnswerCh:
+ return &remoteOfferAnswer, nil
+ case <-h.ctx.Done():
+ // closed externally
+ return nil, NewConnectionClosedError(h.config.Key)
+ }
+}
+
+// sendOffer prepares local user credentials and signals them to the remote peer
+func (h *Handshaker) sendOffer() error {
+ if !h.signaler.Ready() {
+ return ErrSignalIsNotReady
+ }
+
+ iceUFrag, icePwd := h.ice.GetLocalUserCredentials()
+ offer := OfferAnswer{
+ IceCredentials: IceCredentials{iceUFrag, icePwd},
+ WgListenPort: h.config.LocalWgPort,
+ Version: version.NetbirdVersion(),
+ RosenpassPubKey: h.config.RosenpassPubKey,
+ RosenpassAddr: h.config.RosenpassAddr,
+ }
+
+ addr, err := h.relay.RelayInstanceAddress()
+ if err == nil {
+ offer.RelaySrvAddress = addr
+ }
+
+ return h.signaler.SignalOffer(offer, h.config.Key)
+}
+
+func (h *Handshaker) sendAnswer() error {
+ h.log.Debugf("sending answer")
+ uFrag, pwd := h.ice.GetLocalUserCredentials()
+
+ answer := OfferAnswer{
+ IceCredentials: IceCredentials{uFrag, pwd},
+ WgListenPort: h.config.LocalWgPort,
+ Version: version.NetbirdVersion(),
+ RosenpassPubKey: h.config.RosenpassPubKey,
+ RosenpassAddr: h.config.RosenpassAddr,
+ }
+ addr, err := h.relay.RelayInstanceAddress()
+ if err == nil {
+ answer.RelaySrvAddress = addr
+ }
+
+ err = h.signaler.SignalAnswer(answer, h.config.Key)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go
new file mode 100644
index 000000000..713123e5d
--- /dev/null
+++ b/client/internal/peer/signaler.go
@@ -0,0 +1,70 @@
+package peer
+
+import (
+ "github.com/pion/ice/v3"
+ "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
+
+ signal "github.com/netbirdio/netbird/signal/client"
+ sProto "github.com/netbirdio/netbird/signal/proto"
+)
+
+type Signaler struct {
+ signal signal.Client
+ wgPrivateKey wgtypes.Key
+}
+
+func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler {
+ return &Signaler{
+ signal: signal,
+ wgPrivateKey: wgPrivateKey,
+ }
+}
+
+func (s *Signaler) SignalOffer(offer OfferAnswer, remoteKey string) error {
+ return s.signalOfferAnswer(offer, remoteKey, sProto.Body_OFFER)
+}
+
+func (s *Signaler) SignalAnswer(offer OfferAnswer, remoteKey string) error {
+ return s.signalOfferAnswer(offer, remoteKey, sProto.Body_ANSWER)
+}
+
+func (s *Signaler) SignalICECandidate(candidate ice.Candidate, remoteKey string) error {
+ return s.signal.Send(&sProto.Message{
+ Key: s.wgPrivateKey.PublicKey().String(),
+ RemoteKey: remoteKey,
+ Body: &sProto.Body{
+ Type: sProto.Body_CANDIDATE,
+ Payload: candidate.Marshal(),
+ },
+ })
+}
+
+func (s *Signaler) Ready() bool {
+ return s.signal.Ready()
+}
+
+// signalOfferAnswer signals either an offer or an answer to the remote peer
+func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
+ msg, err := signal.MarshalCredential(
+ s.wgPrivateKey,
+ offerAnswer.WgListenPort,
+ remoteKey,
+ &signal.Credential{
+ UFrag: offerAnswer.IceCredentials.UFrag,
+ Pwd: offerAnswer.IceCredentials.Pwd,
+ },
+ bodyType,
+ offerAnswer.RosenpassPubKey,
+ offerAnswer.RosenpassAddr,
+ offerAnswer.RelaySrvAddress)
+ if err != nil {
+ return err
+ }
+
+ err = s.signal.Send(msg)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go
index a7cfb95c4..a28992fac 100644
--- a/client/internal/peer/status.go
+++ b/client/internal/peer/status.go
@@ -3,6 +3,7 @@ package peer
import (
"errors"
"net/netip"
+ "slices"
"sync"
"time"
@@ -10,9 +11,10 @@ import (
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/relay"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
+ relayClient "github.com/netbirdio/netbird/relay/client"
)
// State contains the latest state of a peer
@@ -24,11 +26,11 @@ type State struct {
ConnStatus ConnStatus
ConnStatusUpdate time.Time
Relayed bool
- Direct bool
LocalIceCandidateType string
RemoteIceCandidateType string
LocalIceCandidateEndpoint string
RemoteIceCandidateEndpoint string
+ RelayServerAddress string
LastWireguardHandshake time.Time
BytesTx int64
BytesRx int64
@@ -142,6 +144,8 @@ type Status struct {
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
// set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications()
peerListChangedForNotification bool
+
+ relayMgr *relayClient.Manager
}
// NewRecorder returns a new Status instance
@@ -156,6 +160,12 @@ func NewRecorder(mgmAddress string) *Status {
}
}
+func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+ d.relayMgr = manager
+}
+
// ReplaceOfflinePeers replaces
func (d *Status) ReplaceOfflinePeers(replacement []State) {
d.mux.Lock()
@@ -193,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) {
state, ok := d.peers[peerPubKey]
if !ok {
- return State{}, iface.ErrPeerNotFound
+ return State{}, configurer.ErrPeerNotFound
}
return state, nil
}
@@ -231,17 +241,17 @@ func (d *Status) UpdatePeerState(receivedState State) error {
peerState.SetRoutes(receivedState.GetRoutes())
}
- skipNotification := shouldSkipNotify(receivedState, peerState)
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
if receivedState.ConnStatus != peerState.ConnStatus {
peerState.ConnStatus = receivedState.ConnStatus
peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
- peerState.Direct = receivedState.Direct
peerState.Relayed = receivedState.Relayed
peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
+ peerState.RelayServerAddress = receivedState.RelayServerAddress
peerState.RosenpassEnabled = receivedState.RosenpassEnabled
}
@@ -261,8 +271,148 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil
}
+func (d *Status) UpdatePeerICEState(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ if receivedState.IP != "" {
+ peerState.IP = receivedState.IP
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.Relayed = receivedState.Relayed
+ peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
+ peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
+ peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
+ peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
+ peerState.RosenpassEnabled = receivedState.RosenpassEnabled
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
+func (d *Status) UpdatePeerRelayedState(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.Relayed = receivedState.Relayed
+ peerState.RelayServerAddress = receivedState.RelayServerAddress
+ peerState.RosenpassEnabled = receivedState.RosenpassEnabled
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
+func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.Relayed = receivedState.Relayed
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.RelayServerAddress = ""
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
+func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ peerState, ok := d.peers[receivedState.PubKey]
+ if !ok {
+ return errors.New("peer doesn't exist")
+ }
+
+ skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState)
+
+ peerState.ConnStatus = receivedState.ConnStatus
+ peerState.Relayed = receivedState.Relayed
+ peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate
+ peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType
+ peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType
+ peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint
+ peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint
+
+ d.peers[receivedState.PubKey] = peerState
+
+ if skipNotification {
+ return nil
+ }
+
+ ch, found := d.changeNotify[receivedState.PubKey]
+ if found && ch != nil {
+ close(ch)
+ d.changeNotify[receivedState.PubKey] = nil
+ }
+
+ d.notifyPeerListChanged()
+ return nil
+}
+
// UpdateWireGuardPeerState updates the WireGuard bits of the peer state
-func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error {
+func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error {
d.mux.Lock()
defer d.mux.Unlock()
@@ -280,13 +430,13 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats)
return nil
}
-func shouldSkipNotify(received, curr State) bool {
+func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool {
switch {
- case received.ConnStatus == StatusConnecting:
+ case receivedConnStatus == StatusConnecting:
return true
- case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
+ case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting:
return true
- case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
+ case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected:
return curr.IP != ""
default:
return false
@@ -447,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
}
func (d *Status) GetRosenpassState() RosenpassState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return RosenpassState{
d.rosenpassEnabled,
d.rosenpassPermissive,
@@ -454,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState {
}
func (d *Status) GetManagementState() ManagementState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return ManagementState{
d.mgmAddress,
d.managementState,
@@ -495,6 +649,8 @@ func (d *Status) IsLoginRequired() bool {
}
func (d *Status) GetSignalState() SignalState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return SignalState{
d.signalAddress,
d.signalState,
@@ -502,11 +658,42 @@ func (d *Status) GetSignalState() SignalState {
}
}
+// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
- return d.relayStates
+ d.mux.Lock()
+ defer d.mux.Unlock()
+ if d.relayMgr == nil {
+ return d.relayStates
+ }
+
+ // extend the list of stun, turn servers with relay address
+ relayStates := slices.Clone(d.relayStates)
+
+ var relayState relay.ProbeResult
+
+	// if the server connection is not established then we use the general (configured) addresses;
+	// once a connection is established we use the instance-specific address
+ instanceAddr, err := d.relayMgr.RelayInstanceAddress()
+ if err != nil {
+ // TODO add their status
+ if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
+ for _, r := range d.relayMgr.ServerURLs() {
+ relayStates = append(relayStates, relay.ProbeResult{
+ URI: r,
+ })
+ }
+ return relayStates
+ }
+ relayState.Err = err
+ }
+
+ relayState.URI = instanceAddr
+ return append(relayStates, relayState)
}
func (d *Status) GetDNSStates() []NSGroupState {
+ d.mux.Lock()
+ defer d.mux.Unlock()
return d.nsGroupStates
}
@@ -518,24 +705,24 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
// GetFullStatus gets full status
func (d *Status) GetFullStatus() FullStatus {
- d.mux.Lock()
- defer d.mux.Unlock()
-
fullStatus := FullStatus{
ManagementState: d.GetManagementState(),
SignalState: d.GetSignalState(),
- LocalPeerState: d.localPeer,
Relays: d.GetRelayStates(),
RosenpassState: d.GetRosenpassState(),
NSGroupStates: d.GetDNSStates(),
}
+ d.mux.Lock()
+ defer d.mux.Unlock()
+
+ fullStatus.LocalPeerState = d.localPeer
+
for _, status := range d.peers {
fullStatus.Peers = append(fullStatus.Peers, status)
}
fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...)
-
return fullStatus
}
diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go
index a4a6e6081..1d283433b 100644
--- a/client/internal/peer/status_test.go
+++ b/client/internal/peer/status_test.go
@@ -2,8 +2,8 @@ package peer
import (
"errors"
- "testing"
"sync"
+ "testing"
"github.com/stretchr/testify/assert"
)
@@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
@@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) {
status := NewRecorder("https://mgm")
peerState := State{
PubKey: key,
- Mux: new(sync.RWMutex),
+ Mux: new(sync.RWMutex),
}
status.peers[key] = peerState
diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go
index 13f5886f5..ae31ebbf0 100644
--- a/client/internal/peer/stdnet.go
+++ b/client/internal/peer/stdnet.go
@@ -6,6 +6,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet"
)
-func (conn *Conn) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNet(conn.config.InterfaceBlackList)
+func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
+ return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList)
}
diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go
index 8a2454371..b411405bb 100644
--- a/client/internal/peer/stdnet_android.go
+++ b/client/internal/peer/stdnet_android.go
@@ -2,6 +2,6 @@ package peer
import "github.com/netbirdio/netbird/client/internal/stdnet"
-func (conn *Conn) newStdNet() (*stdnet.Net, error) {
- return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList)
+func (w *WorkerICE) newStdNet() (*stdnet.Net, error) {
+ return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList)
}
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
new file mode 100644
index 000000000..c4e9d1950
--- /dev/null
+++ b/client/internal/peer/worker_ice.go
@@ -0,0 +1,470 @@
+package peer
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/netip"
+ "runtime"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/pion/ice/v3"
+ "github.com/pion/randutil"
+ "github.com/pion/stun/v2"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/bind"
+ "github.com/netbirdio/netbird/client/internal/stdnet"
+ "github.com/netbirdio/netbird/route"
+)
+
+const (
+ iceKeepAliveDefault = 4 * time.Second
+ iceDisconnectedTimeoutDefault = 6 * time.Second
+ // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package
+ iceRelayAcceptanceMinWaitDefault = 2 * time.Second
+
+ lenUFrag = 16
+ lenPwd = 32
+ runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+)
+
+var (
+ failedTimeout = 6 * time.Second
+)
+
+type ICEConfig struct {
+ // StunTurn is a list of STUN and TURN URLs
+ StunTurn *atomic.Value // []*stun.URI
+
+ // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering
+ // (e.g. if eth0 is in the list, host candidate of this interface won't be used)
+ InterfaceBlackList []string
+ DisableIPv6Discovery bool
+
+ UDPMux ice.UDPMux
+ UDPMuxSrflx ice.UniversalUDPMux
+
+ NATExternalIPs []string
+}
+
+type ICEConnInfo struct {
+ RemoteConn net.Conn
+ RosenpassPubKey []byte
+ RosenpassAddr string
+ LocalIceCandidateType string
+ RemoteIceCandidateType string
+ RemoteIceCandidateEndpoint string
+ LocalIceCandidateEndpoint string
+ Relayed bool
+ RelayedOnLocal bool
+}
+
+type WorkerICECallbacks struct {
+ OnConnReady func(ConnPriority, ICEConnInfo)
+ OnStatusChanged func(ConnStatus)
+}
+
+type WorkerICE struct {
+ ctx context.Context
+ log *log.Entry
+ config ConnConfig
+ signaler *Signaler
+ iFaceDiscover stdnet.ExternalIFaceDiscover
+ statusRecorder *Status
+ hasRelayOnLocally bool
+ conn WorkerICECallbacks
+
+ selectedPriority ConnPriority
+
+ agent *ice.Agent
+ muxAgent sync.Mutex
+
+ StunTurn []*stun.URI
+
+ sentExtraSrflx bool
+
+ localUfrag string
+ localPwd string
+}
+
+func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) {
+ w := &WorkerICE{
+ ctx: ctx,
+ log: log,
+ config: config,
+ signaler: signaler,
+ iFaceDiscover: ifaceDiscover,
+ statusRecorder: statusRecorder,
+ hasRelayOnLocally: hasRelayOnLocally,
+ conn: callBacks,
+ }
+
+ localUfrag, localPwd, err := generateICECredentials()
+ if err != nil {
+ return nil, err
+ }
+ w.localUfrag = localUfrag
+ w.localPwd = localPwd
+ return w, nil
+}
+
+func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
+ w.log.Debugf("OnNewOffer for ICE")
+ w.muxAgent.Lock()
+
+ if w.agent != nil {
+ w.log.Debugf("agent already exists, skipping the offer")
+ w.muxAgent.Unlock()
+ return
+ }
+
+ var preferredCandidateTypes []ice.CandidateType
+ if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
+ w.selectedPriority = connPriorityICEP2P
+ preferredCandidateTypes = candidateTypesP2P()
+ } else {
+ w.selectedPriority = connPriorityICETurn
+ preferredCandidateTypes = candidateTypes()
+ }
+
+ w.log.Debugf("recreate ICE agent")
+ agentCtx, agentCancel := context.WithCancel(w.ctx)
+ agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes)
+ if err != nil {
+ w.log.Errorf("failed to recreate ICE Agent: %s", err)
+ w.muxAgent.Unlock()
+ return
+ }
+ w.agent = agent
+ w.muxAgent.Unlock()
+
+ w.log.Debugf("gather candidates")
+ err = w.agent.GatherCandidates()
+ if err != nil {
+ w.log.Debugf("failed to gather candidates: %s", err)
+ return
+ }
+
+ // will block until connection succeeded
+ // but it won't release if ICE Agent went into Disconnected or Failed state,
+ // so we have to cancel it with the provided context once agent detected a broken connection
+ w.log.Debugf("turn agent dial")
+ remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer)
+ if err != nil {
+ w.log.Debugf("failed to dial the remote peer: %s", err)
+ return
+ }
+ w.log.Debugf("agent dial succeeded")
+
+ pair, err := w.agent.GetSelectedCandidatePair()
+ if err != nil {
+ return
+ }
+
+ if !isRelayCandidate(pair.Local) {
+ // dynamically set remote WireGuard port if other side specified a different one from the default one
+ remoteWgPort := iface.DefaultWgPort
+ if remoteOfferAnswer.WgListenPort != 0 {
+ remoteWgPort = remoteOfferAnswer.WgListenPort
+ }
+
+		// To support old versions with direct mode we attempt to punch an additional hole with the remote WireGuard port
+ go w.punchRemoteWGPort(pair, remoteWgPort)
+ }
+
+ ci := ICEConnInfo{
+ RemoteConn: remoteConn,
+ RosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
+ RosenpassAddr: remoteOfferAnswer.RosenpassAddr,
+ LocalIceCandidateType: pair.Local.Type().String(),
+ RemoteIceCandidateType: pair.Remote.Type().String(),
+ LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()),
+ RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()),
+ Relayed: isRelayed(pair),
+ RelayedOnLocal: isRelayCandidate(pair.Local),
+ }
+ w.log.Debugf("on ICE conn read to use ready")
+ go w.conn.OnConnReady(w.selectedPriority, ci)
+}
+
+// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
+func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
+ w.muxAgent.Lock()
+ defer w.muxAgent.Unlock()
+ w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String())
+ if w.agent == nil {
+ w.log.Warnf("ICE Agent is not initialized yet")
+ return
+ }
+
+ if candidateViaRoutes(candidate, haRoutes) {
+ return
+ }
+
+ err := w.agent.AddRemoteCandidate(candidate)
+ if err != nil {
+ w.log.Errorf("error while handling remote candidate")
+ return
+ }
+}
+
+func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) {
+ w.muxAgent.Lock()
+ defer w.muxAgent.Unlock()
+ return w.localUfrag, w.localPwd
+}
+
+func (w *WorkerICE) Close() {
+ w.muxAgent.Lock()
+ defer w.muxAgent.Unlock()
+
+ if w.agent == nil {
+ return
+ }
+
+ err := w.agent.Close()
+ if err != nil {
+ w.log.Warnf("failed to close ICE agent: %s", err)
+ }
+}
+
+func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) {
+ transportNet, err := w.newStdNet()
+ if err != nil {
+ w.log.Errorf("failed to create pion's stdnet: %s", err)
+ }
+
+ iceKeepAlive := iceKeepAlive()
+ iceDisconnectedTimeout := iceDisconnectedTimeout()
+ iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait()
+
+ agentConfig := &ice.AgentConfig{
+ MulticastDNSMode: ice.MulticastDNSModeDisabled,
+ NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6},
+ Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI),
+ CandidateTypes: relaySupport,
+ InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList),
+ UDPMux: w.config.ICEConfig.UDPMux,
+ UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx,
+ NAT1To1IPs: w.config.ICEConfig.NATExternalIPs,
+ Net: transportNet,
+ FailedTimeout: &failedTimeout,
+ DisconnectedTimeout: &iceDisconnectedTimeout,
+ KeepaliveInterval: &iceKeepAlive,
+ RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait,
+ LocalUfrag: w.localUfrag,
+ LocalPwd: w.localPwd,
+ }
+
+ if w.config.ICEConfig.DisableIPv6Discovery {
+ agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4}
+ }
+
+ w.sentExtraSrflx = false
+ agent, err := ice.NewAgent(agentConfig)
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnCandidate(w.onICECandidate)
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnConnectionStateChange(func(state ice.ConnectionState) {
+ w.log.Debugf("ICE ConnectionState has changed to %s", state.String())
+ if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected {
+ w.conn.OnStatusChanged(StatusDisconnected)
+
+ w.muxAgent.Lock()
+ agentCancel()
+ _ = agent.Close()
+ w.agent = nil
+
+ w.muxAgent.Unlock()
+ }
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair)
+ if err != nil {
+ return nil, err
+ }
+
+ err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) {
+ err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency())
+ if err != nil {
+ w.log.Debugf("failed to update latency for peer: %s", err)
+ return
+ }
+ })
+ if err != nil {
+ return nil, fmt.Errorf("failed setting binding response callback: %w", err)
+ }
+
+ return agent, nil
+}
+
+func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) {
+	// wait for the local endpoint configuration to be applied
+ time.Sleep(time.Second)
+ addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort))
+ if err != nil {
+ w.log.Warnf("got an error while resolving the udp address, err: %s", err)
+ return
+ }
+
+ mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault)
+ if !ok {
+ w.log.Warn("invalid udp mux conversion")
+ return
+ }
+ _, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr)
+ if err != nil {
+ w.log.Warnf("got an error while sending the punch packet, err: %s", err)
+ }
+}
+
+// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates
+// and then signals them to the remote peer
+func (w *WorkerICE) onICECandidate(candidate ice.Candidate) {
+ // nil means candidate gathering has been ended
+ if candidate == nil {
+ return
+ }
+
+ // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored
+ w.log.Debugf("discovered local candidate %s", candidate.String())
+ go func() {
+ err := w.signaler.SignalICECandidate(candidate, w.config.Key)
+ if err != nil {
+ w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err)
+ }
+ }()
+
+ if !w.shouldSendExtraSrflxCandidate(candidate) {
+ return
+ }
+
+ // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port)
+	// this is useful when the network has an existing port forwarding rule for the wireguard port of this peer
+ extraSrflx, err := extraSrflxCandidate(candidate)
+ if err != nil {
+ w.log.Errorf("failed creating extra server reflexive candidate %s", err)
+ return
+ }
+ w.sentExtraSrflx = true
+
+ go func() {
+ err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key)
+ if err != nil {
+ w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err)
+ }
+ }()
+}
+
+func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) {
+ w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(),
+ w.config.Key)
+}
+
+func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool {
+ if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port {
+ return true
+ }
+ return false
+}
+
+func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
+ isControlling := w.config.LocalKey > w.config.Key
+ if isControlling {
+ return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ } else {
+ return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+ }
+}
+
+func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) {
+ relatedAdd := candidate.RelatedAddress()
+ return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{
+ Network: candidate.NetworkType().String(),
+ Address: candidate.Address(),
+ Port: relatedAdd.Port,
+ Component: candidate.Component(),
+ RelAddr: relatedAdd.Address,
+ RelPort: relatedAdd.Port,
+ })
+}
+
+func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
+ var routePrefixes []netip.Prefix
+ for _, routes := range clientRoutes {
+ if len(routes) > 0 && routes[0] != nil {
+ routePrefixes = append(routePrefixes, routes[0].Network)
+ }
+ }
+
+ addr, err := netip.ParseAddr(candidate.Address())
+ if err != nil {
+ log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
+ return false
+ }
+
+ for _, prefix := range routePrefixes {
+		// skip the default route (prefix length 0), it would match every candidate
+ if prefix.Bits() == 0 {
+ continue
+ }
+
+ if prefix.Contains(addr) {
+ log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
+ return true
+ }
+ }
+ return false
+}
+
+func candidateTypes() []ice.CandidateType {
+ if hasICEForceRelayConn() {
+ return []ice.CandidateType{ice.CandidateTypeRelay}
+ }
+ // TODO: remove this once we have refactored userspace proxy into the bind package
+ if runtime.GOOS == "ios" {
+ return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
+ }
+ return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay}
+}
+
+func candidateTypesP2P() []ice.CandidateType {
+ return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive}
+}
+
+func isRelayCandidate(candidate ice.Candidate) bool {
+ return candidate.Type() == ice.CandidateTypeRelay
+}
+
+func isRelayed(pair *ice.CandidatePair) bool {
+ if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay {
+ return true
+ }
+ return false
+}
+
+func generateICECredentials() (string, string, error) {
+ ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha)
+ if err != nil {
+ return "", "", err
+ }
+
+ pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha)
+ if err != nil {
+ return "", "", err
+ }
+ return ufrag, pwd, nil
+}
diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go
new file mode 100644
index 000000000..c02fccebc
--- /dev/null
+++ b/client/internal/peer/worker_relay.go
@@ -0,0 +1,237 @@
+package peer
+
+import (
+ "context"
+ "errors"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ relayClient "github.com/netbirdio/netbird/relay/client"
+)
+
+var (
+ wgHandshakePeriod = 3 * time.Minute
+ wgHandshakeOvertime = 30 * time.Second
+)
+
+type RelayConnInfo struct {
+ relayedConn net.Conn
+ rosenpassPubKey []byte
+ rosenpassAddr string
+}
+
+type WorkerRelayCallbacks struct {
+ OnConnReady func(RelayConnInfo)
+ OnDisconnected func()
+}
+
+type WorkerRelay struct {
+ log *log.Entry
+ config ConnConfig
+ relayManager relayClient.ManagerService
+ callBacks WorkerRelayCallbacks
+
+ relayedConn net.Conn
+ relayLock sync.Mutex
+ ctxWgWatch context.Context
+ ctxCancelWgWatch context.CancelFunc
+ ctxLock sync.Mutex
+
+ relaySupportedOnRemotePeer atomic.Bool
+}
+
+func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay {
+ r := &WorkerRelay{
+ log: log,
+ config: config,
+ relayManager: relayManager,
+ callBacks: callbacks,
+ }
+ return r
+}
+
+func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
+ if !w.isRelaySupported(remoteOfferAnswer) {
+ w.log.Infof("Relay is not supported by remote peer")
+ w.relaySupportedOnRemotePeer.Store(false)
+ return
+ }
+ w.relaySupportedOnRemotePeer.Store(true)
+
+	// the relayManager will return an error if the connection to the relay server has been lost
+ currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
+ if err != nil {
+ w.log.Errorf("failed to handle new offer: %s", err)
+ return
+ }
+
+ srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
+
+ relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key)
+ if err != nil {
+ if errors.Is(err, relayClient.ErrConnAlreadyExists) {
+ w.log.Debugf("handled offer by reusing existing relay connection")
+ return
+ }
+ w.log.Errorf("failed to open connection via Relay: %s", err)
+ return
+ }
+ w.relayLock.Lock()
+ w.relayedConn = relayedConn
+ w.relayLock.Unlock()
+
+ err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected)
+ if err != nil {
+ log.Errorf("failed to add close listener: %s", err)
+ _ = relayedConn.Close()
+ return
+ }
+
+ w.log.Debugf("peer conn opened via Relay: %s", srv)
+ go w.callBacks.OnConnReady(RelayConnInfo{
+ relayedConn: relayedConn,
+ rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey,
+ rosenpassAddr: remoteOfferAnswer.RosenpassAddr,
+ })
+}
+
+func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) {
+ w.log.Debugf("enable WireGuard watcher")
+ w.ctxLock.Lock()
+ defer w.ctxLock.Unlock()
+
+ if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil {
+ return
+ }
+
+ ctx, ctxCancel := context.WithCancel(ctx)
+ w.ctxWgWatch = ctx
+ w.ctxCancelWgWatch = ctxCancel
+
+ w.wgStateCheck(ctx, ctxCancel)
+}
+
+func (w *WorkerRelay) DisableWgWatcher() {
+ w.ctxLock.Lock()
+ defer w.ctxLock.Unlock()
+
+ if w.ctxCancelWgWatch == nil {
+ return
+ }
+
+ w.log.Debugf("disable WireGuard watcher")
+
+ w.ctxCancelWgWatch()
+}
+
+func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
+ return w.relayManager.RelayInstanceAddress()
+}
+
+func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool {
+ return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally()
+}
+
+func (w *WorkerRelay) IsController() bool {
+ return w.config.LocalKey > w.config.Key
+}
+
+func (w *WorkerRelay) RelayIsSupportedLocally() bool {
+ return w.relayManager.HasRelayAddress()
+}
+
+func (w *WorkerRelay) CloseConn() {
+ w.relayLock.Lock()
+ defer w.relayLock.Unlock()
+ if w.relayedConn == nil {
+ return
+ }
+
+ err := w.relayedConn.Close()
+ if err != nil {
+ w.log.Warnf("failed to close relay connection: %v", err)
+ }
+}
+
+// wgStateCheck helps to check the state of the WireGuard handshake and relay connection
+func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) {
+ w.log.Debugf("WireGuard watcher started")
+ lastHandshake, err := w.wgState()
+ if err != nil {
+ w.log.Warnf("failed to read wg stats: %v", err)
+ lastHandshake = time.Time{}
+ }
+
+ go func(lastHandshake time.Time) {
+ timer := time.NewTimer(wgHandshakeOvertime)
+ defer timer.Stop()
+ defer ctxCancel()
+
+ for {
+ select {
+ case <-timer.C:
+ handshake, err := w.wgState()
+ if err != nil {
+ w.log.Errorf("failed to read wg stats: %v", err)
+ timer.Reset(wgHandshakeOvertime)
+ continue
+ }
+
+ w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
+
+ if handshake.Equal(lastHandshake) {
+ w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake)
+ w.relayLock.Lock()
+ _ = w.relayedConn.Close()
+ w.relayLock.Unlock()
+ w.callBacks.OnDisconnected()
+ return
+ }
+
+ resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime))
+ lastHandshake = handshake
+ timer.Reset(resetTime)
+ case <-ctx.Done():
+ w.log.Debugf("WireGuard watcher stopped")
+ return
+ }
+ }
+ }(lastHandshake)
+
+}
+
+func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool {
+ if !w.relayManager.HasRelayAddress() {
+ return false
+ }
+ return answer.RelaySrvAddress != ""
+}
+
+func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string {
+ if w.IsController() {
+ return myRelayAddress
+ }
+ return remoteRelayAddress
+}
+
+func (w *WorkerRelay) wgState() (time.Time, error) {
+ wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key)
+ if err != nil {
+ return time.Time{}, err
+ }
+ return wgState.LastHandshake, nil
+}
+
+func (w *WorkerRelay) onRelayMGDisconnected() {
+ w.ctxLock.Lock()
+ defer w.ctxLock.Unlock()
+
+ if w.ctxCancelWgWatch != nil {
+ w.ctxCancelWgWatch()
+ }
+ go w.callBacks.OnDisconnected()
+}
diff --git a/client/internal/probe.go b/client/internal/probe.go
index 743b6b190..23290cf74 100644
--- a/client/internal/probe.go
+++ b/client/internal/probe.go
@@ -2,6 +2,13 @@ package internal
import "context"
+type ProbeHolder struct {
+ MgmProbe *Probe
+ SignalProbe *Probe
+ RelayProbe *Probe
+ WgProbe *Probe
+}
+
// Probe allows to run on-demand callbacks from different code locations.
// Pass the probe to a receiving and a sending end. The receiving end starts listening
// to requests with Receive and executes a callback when the sending end requests it
diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go
index 4542a37fe..7d98a6060 100644
--- a/client/internal/relay/relay.go
+++ b/client/internal/relay/relay.go
@@ -17,7 +17,7 @@ import (
// ProbeResult holds the info about the result of a relay probe request
type ProbeResult struct {
- URI *stun.URI
+ URI string
Err error
Addr string
}
@@ -176,7 +176,7 @@ func ProbeAll(
wg.Add(1)
go func(res *ProbeResult, stunURI *stun.URI) {
defer wg.Done()
- res.URI = stunURI
+ res.URI = stunURI.String()
res.Addr, res.Err = fn(ctx, stunURI)
}(&results[i], uri)
}
diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go
index 1566d10dd..eaa232151 100644
--- a/client/internal/routemanager/client.go
+++ b/client/internal/routemanager/client.go
@@ -10,19 +10,18 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
type routerPeerStatus struct {
connected bool
relayed bool
- direct bool
latency time.Duration
}
@@ -44,7 +43,7 @@ type clientNetwork struct {
ctx context.Context
cancel context.CancelFunc
statusRecorder *peer.Status
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
routes map[route.ID]*route.Route
routeUpdate chan routesUpdate
peerStateUpdate chan struct{}
@@ -54,7 +53,7 @@ type clientNetwork struct {
updateSerial uint64
}
-func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
+func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
@@ -82,7 +81,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
routePeerStatuses[r.ID] = routerPeerStatus{
connected: peerStatus.ConnStatus == peer.StatusConnected,
relayed: peerStatus.Relayed,
- direct: peerStatus.Direct,
latency: peerStatus.Latency,
}
}
@@ -97,8 +95,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus {
// * Connected peers: Only routes with connected peers are considered.
// * Metric: Routes with lower metrics (better) are prioritized.
// * Non-relayed: Routes without relays are preferred.
-// * Direct connections: Routes with direct peer connections are favored.
// * Latency: Routes with lower latency are prioritized.
+// * Flap prevention: a candidate's score is compared against the chosen score plus a 10ms latency margin, so routes only switch when meaningfully better
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
//
// It returns the ID of the selected optimal route.
@@ -137,10 +135,6 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID]
tempScore++
}
- if peerStatus.direct {
- tempScore++
- }
-
if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") {
chosen = r.ID
chosenScore = tempScore
@@ -384,7 +378,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
-func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler {
+func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
if rt.IsDynamic() {
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go
index 0ae10e568..583156e4d 100644
--- a/client/internal/routemanager/client_test.go
+++ b/client/internal/routemanager/client_test.go
@@ -24,7 +24,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -43,7 +42,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: true,
- direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -62,7 +60,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: true,
- direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -81,7 +78,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: false,
relayed: false,
- direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -100,12 +96,10 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -129,41 +123,10 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
},
"route2": {
connected: true,
relayed: true,
- direct: true,
- },
- },
- existingRoutes: map[route.ID]*route.Route{
- "route1": {
- ID: "route1",
- Metric: route.MaxMetric,
- Peer: "peer1",
- },
- "route2": {
- ID: "route2",
- Metric: route.MaxMetric,
- Peer: "peer2",
- },
- },
- currentRoute: "",
- expectedRouteID: "route1",
- },
- {
- name: "multiple connected peers with one direct",
- statuses: map[route.ID]routerPeerStatus{
- "route1": {
- connected: true,
- relayed: false,
- direct: true,
- },
- "route2": {
- connected: true,
- relayed: false,
- direct: false,
},
},
existingRoutes: map[route.ID]*route.Route{
@@ -241,13 +204,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
latency: 15 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
latency: 10 * time.Millisecond,
},
},
@@ -272,13 +233,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
latency: 200 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
latency: 10 * time.Millisecond,
},
},
@@ -303,13 +262,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
"route1": {
connected: true,
relayed: false,
- direct: true,
latency: 20 * time.Millisecond,
},
"route2": {
connected: true,
relayed: false,
- direct: true,
latency: 10 * time.Millisecond,
},
},
diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go
index 3296f3ddf..ac94d4a5c 100644
--- a/client/internal/routemanager/dynamic/route.go
+++ b/client/internal/routemanager/dynamic/route.go
@@ -13,10 +13,10 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
@@ -48,7 +48,7 @@ type Route struct {
currentPeerKey string
cancel context.CancelFunc
statusRecorder *peer.Status
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
resolverAddr string
}
@@ -58,7 +58,7 @@ func NewRoute(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
interval time.Duration,
statusRecorder *peer.Status,
- wgInterface *iface.WGIface,
+ wgInterface iface.IWGIface,
resolverAddr string,
) *Route {
return &Route{
@@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti
var merr *multierror.Error
for _, prefix := range prefixes {
- if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil {
+ if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err))
continue
}
diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go
index 0b10dbe33..d7ddf7ae8 100644
--- a/client/internal/routemanager/manager.go
+++ b/client/internal/routemanager/manager.go
@@ -14,6 +14,8 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
+ "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
@@ -21,7 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/routeselector"
- "github.com/netbirdio/netbird/iface"
+ relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
nbnet "github.com/netbirdio/netbird/util/net"
"github.com/netbirdio/netbird/version"
@@ -49,7 +51,8 @@ type DefaultManager struct {
serverRouter serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
- wgInterface *iface.WGIface
+ relayMgr *relayClient.Manager
+ wgInterface iface.IWGIface
pubKey string
notifier *notifier.Notifier
routeRefCounter *refcounter.RouteRefCounter
@@ -61,8 +64,9 @@ func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
- wgInterface *iface.WGIface,
+ wgInterface iface.IWGIface,
statusRecorder *peer.Status,
+ relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
@@ -74,6 +78,7 @@ func NewManager(
stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
+ relayMgr: relayMgr,
routeSelector: routeselector.NewRouteSelector(),
sysOps: sysOps,
statusRecorder: statusRecorder,
@@ -83,10 +88,10 @@ func NewManager(
}
dm.routeRefCounter = refcounter.New(
- func(prefix netip.Prefix, _ any) (any, error) {
- return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
+ func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
+ return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
},
- func(prefix netip.Prefix, _ any) error {
+ func(prefix netip.Prefix, _ struct{}) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
},
)
@@ -98,7 +103,7 @@ func NewManager(
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
- if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) {
+ if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
return err
}
log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err)
@@ -124,9 +129,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
log.Warnf("Failed cleaning up routing: %v", err)
}
- mgmtAddress := m.statusRecorder.GetManagementState().URL
- signalAddress := m.statusRecorder.GetSignalState().URL
- ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
+ initialAddresses := []string{m.statusRecorder.GetManagementState().URL, m.statusRecorder.GetSignalState().URL}
+ if m.relayMgr != nil {
+ initialAddresses = append(initialAddresses, m.relayMgr.ServerURLs()...)
+ }
+
+ ips := resolveURLsToIPs(initialAddresses)
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips)
if err != nil {
diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go
index 455c7ac0b..2f26f7a5e 100644
--- a/client/internal/routemanager/manager_test.go
+++ b/client/internal/routemanager/manager_test.go
@@ -12,8 +12,8 @@ import (
"github.com/stretchr/testify/require"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
@@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
- routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil)
+ routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil)
_, _, err = routeManager.Init()
diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go
index 58a66715c..908279c88 100644
--- a/client/internal/routemanager/mock.go
+++ b/client/internal/routemanager/mock.go
@@ -5,9 +5,9 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/routeselector"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util/net"
)
diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go
index f1d696ad9..65ea0f708 100644
--- a/client/internal/routemanager/refcounter/refcounter.go
+++ b/client/internal/routemanager/refcounter/refcounter.go
@@ -3,7 +3,8 @@ package refcounter
import (
"errors"
"fmt"
- "net/netip"
+ "runtime"
+ "strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -12,118 +13,153 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
)
-// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix.
+const logLevel = log.TraceLevel
+
+// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key.
var ErrIgnore = errors.New("ignore")
+// Ref holds the reference count and associated data for a key.
type Ref[O any] struct {
Count int
Out O
}
-type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error)
-type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error
+// AddFunc is the function type for adding a new key.
+// Key is the type of the key (e.g., netip.Prefix).
+type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error)
-type Counter[I, O any] struct {
- // refCountMap keeps track of the reference Ref for prefixes
- refCountMap map[netip.Prefix]Ref[O]
+// RemoveFunc is the function type for removing a key.
+type RemoveFunc[Key, O any] func(key Key, out O) error
+
+// Counter is a generic reference counter for managing keys and their associated data.
+// Key: The type of the key (e.g., netip.Prefix, string).
+//
+// I: The input type for the AddFunc. It is the input type for additional data needed
+// when adding a key, it is passed as the second argument to AddFunc.
+//
+// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc.
+// It is stored and passed to RemoveFunc when the reference count reaches 0.
+//
+// The types can be aliased to a specific type using the following syntax:
+//
+// type RouteRefCounter = Counter[netip.Prefix, any, any]
+type Counter[Key comparable, I, O any] struct {
+ // refCountMap keeps track of the reference Ref for keys
+ refCountMap map[Key]Ref[O]
refCountMu sync.Mutex
- // idMap keeps track of the prefixes associated with an ID for removal
- idMap map[string][]netip.Prefix
+ // idMap keeps track of the keys associated with an ID for removal
+ idMap map[string][]Key
idMu sync.Mutex
- add AddFunc[I, O]
- remove RemoveFunc[I, O]
+ add AddFunc[Key, I, O]
+ remove RemoveFunc[Key, O]
}
-// New creates a new Counter instance
-func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] {
- return &Counter[I, O]{
- refCountMap: map[netip.Prefix]Ref[O]{},
- idMap: map[string][]netip.Prefix{},
+// New creates a new Counter instance.
+// Usage example:
+//
+// counter := New[netip.Prefix, string, string](
+// func(key netip.Prefix, in string) (out string, err error) { ... },
+ func(key netip.Prefix, out string) error { ... },
+// )
+func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] {
+ return &Counter[Key, I, O]{
+ refCountMap: map[Key]Ref[O]{},
+ idMap: map[string][]Key{},
add: add,
remove: remove,
}
}
-// Increment increments the reference count for the given prefix.
-// If this is the first reference to the prefix, the AddFunc is called.
-func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) {
+// Get retrieves the current reference count and associated data for a key.
+// If the key doesn't exist, it returns a zero value Ref and false.
+func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
- ref := rm.refCountMap[prefix]
- log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
+ ref, ok := rm.refCountMap[key]
+ return ref, ok
+}
- // Call AddFunc only if it's a new prefix
+// Increment increments the reference count for the given key.
+// If this is the first reference to the key, the AddFunc is called.
+func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
+ rm.refCountMu.Lock()
+ defer rm.refCountMu.Unlock()
+
+ ref := rm.refCountMap[key]
+ logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
+
+ // Call AddFunc only if it's a new key
if ref.Count == 0 {
- log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out)
- out, err := rm.add(prefix, in)
+ logCallerF("Calling add for key %v", key)
+ out, err := rm.add(key, in)
if errors.Is(err, ErrIgnore) {
return ref, nil
}
if err != nil {
- return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err)
+ return ref, fmt.Errorf("failed to add for key %v: %w", key, err)
}
ref.Out = out
}
ref.Count++
- rm.refCountMap[prefix] = ref
+ rm.refCountMap[key] = ref
return ref, nil
}
-// IncrementWithID increments the reference count for the given prefix and groups it under the given ID.
-// If this is the first reference to the prefix, the AddFunc is called.
-func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) {
+// IncrementWithID increments the reference count for the given key and groups it under the given ID.
+// If this is the first reference to the key, the AddFunc is called.
+func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
rm.idMu.Lock()
defer rm.idMu.Unlock()
- ref, err := rm.Increment(prefix, in)
+ ref, err := rm.Increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
- rm.idMap[id] = append(rm.idMap[id], prefix)
+ rm.idMap[id] = append(rm.idMap[id], key)
return ref, nil
}
-// Decrement decrements the reference count for the given prefix.
+// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
-func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) {
+func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
- ref, ok := rm.refCountMap[prefix]
+ ref, ok := rm.refCountMap[key]
if !ok {
- log.Tracef("No reference found for prefix %s", prefix)
+ logCallerF("No reference found for key %v", key)
return ref, nil
}
- log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out)
+ logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out)
if ref.Count == 1 {
- log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out)
- if err := rm.remove(prefix, ref.Out); err != nil {
- return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err)
+ logCallerF("Calling remove for key %v", key)
+ if err := rm.remove(key, ref.Out); err != nil {
+ return ref, fmt.Errorf("remove for key %v: %w", key, err)
}
- delete(rm.refCountMap, prefix)
+ delete(rm.refCountMap, key)
} else {
ref.Count--
- rm.refCountMap[prefix] = ref
+ rm.refCountMap[key] = ref
}
return ref, nil
}
-// DecrementWithID decrements the reference count for all prefixes associated with the given ID.
+// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
-func (rm *Counter[I, O]) DecrementWithID(id string) error {
+func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
- for _, prefix := range rm.idMap[id] {
- if _, err := rm.Decrement(prefix); err != nil {
+ for _, key := range rm.idMap[id] {
+ if _, err := rm.Decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error {
return nberrors.FormatErrorOrNil(merr)
}
-// Flush removes all references and calls RemoveFunc for each prefix.
-func (rm *Counter[I, O]) Flush() error {
+// Flush removes all references and calls RemoveFunc for each key.
+func (rm *Counter[Key, I, O]) Flush() error {
rm.refCountMu.Lock()
defer rm.refCountMu.Unlock()
rm.idMu.Lock()
defer rm.idMu.Unlock()
var merr *multierror.Error
- for prefix := range rm.refCountMap {
- log.Tracef("Removing for prefix %s", prefix)
- ref := rm.refCountMap[prefix]
- if err := rm.remove(prefix, ref.Out); err != nil {
- merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err))
+ for key := range rm.refCountMap {
+ logCallerF("Calling remove for key %v", key)
+ ref := rm.refCountMap[key]
+ if err := rm.remove(key, ref.Out); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err))
}
}
- rm.refCountMap = map[netip.Prefix]Ref[O]{}
- rm.idMap = map[string][]netip.Prefix{}
+ clear(rm.refCountMap)
+ clear(rm.idMap)
return nberrors.FormatErrorOrNil(merr)
}
+
+// Clear removes all references without calling RemoveFunc.
+func (rm *Counter[Key, I, O]) Clear() {
+ rm.refCountMu.Lock()
+ defer rm.refCountMu.Unlock()
+ rm.idMu.Lock()
+ defer rm.idMu.Unlock()
+
+ clear(rm.refCountMap)
+ clear(rm.idMap)
+}
+
+func getCallerInfo(depth int, maxDepth int) (string, bool) {
+ if depth >= maxDepth {
+ return "", false
+ }
+
+ pc, _, _, ok := runtime.Caller(depth)
+ if !ok {
+ return "", false
+ }
+
+ if details := runtime.FuncForPC(pc); details != nil {
+ name := details.Name()
+
+ lastDotIndex := strings.LastIndex(name, "/")
+ if lastDotIndex != -1 {
+ name = name[lastDotIndex+1:]
+ }
+
+ if strings.HasPrefix(name, "refcounter.") {
+ // +2 to account for recursion
+ return getCallerInfo(depth+2, maxDepth)
+ }
+
+ return name, true
+ }
+
+ return "", false
+}
+
+// logCallerF logs a formatted message prefixed with the package name and method of the function that called the current function.
+func logCallerF(format string, args ...interface{}) {
+ if log.GetLevel() < logLevel {
+ return
+ }
+
+ if callerName, ok := getCallerInfo(3, 18); ok {
+ format = fmt.Sprintf("[%s] %s", callerName, format)
+ }
+
+ log.StandardLogger().Logf(logLevel, format, args...)
+}
diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go
index 6753b64ef..aadac3e25 100644
--- a/client/internal/routemanager/refcounter/types.go
+++ b/client/internal/routemanager/refcounter/types.go
@@ -1,7 +1,9 @@
package refcounter
+import "net/netip"
+
// RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement
-type RouteRefCounter = Counter[any, any]
+type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}]
// AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement
-type AllowedIPsRefCounter = Counter[string, string]
+type AllowedIPsRefCounter = Counter[netip.Prefix, string, string]
diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go
index b4065bca6..c75a0a7f2 100644
--- a/client/internal/routemanager/server_android.go
+++ b/client/internal/routemanager/server_android.go
@@ -7,10 +7,10 @@ import (
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
- "github.com/netbirdio/netbird/iface"
)
-func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
+func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}
diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go
index 8470934c2..ef38d5707 100644
--- a/client/internal/routemanager/server_nonandroid.go
+++ b/client/internal/routemanager/server_nonandroid.go
@@ -11,9 +11,9 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
- "github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/route"
)
@@ -22,11 +22,11 @@ type defaultServerRouter struct {
ctx context.Context
routes map[route.ID]*route.Route
firewall firewall.Manager
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
statusRecorder *peer.Status
}
-func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
+func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
@@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
return fmt.Errorf("parse prefix: %w", err)
}
- err = m.firewall.RemoveRoutingRules(routerPair)
+ err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
@@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
return fmt.Errorf("parse prefix: %w", err)
}
- err = m.firewall.InsertRoutingRules(routerPair)
+ err = m.firewall.AddNatRule(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
@@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() {
continue
}
- err = m.firewall.RemoveRoutingRules(routerPair)
+ err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
log.Errorf("Failed to remove cleanup route: %v", err)
}
@@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {
// TODO: add ipv6
source := getDefaultPrefix(route.Network)
- destination := route.Network.Masked().String()
+ destination := route.Network.Masked()
if route.IsDynamic() {
- // TODO: add ipv6
- destination = "0.0.0.0/0"
+ // TODO: additionally add IPv6 support
+ destination = getDefaultPrefix(destination)
}
return firewall.RouterPair{
- ID: string(route.ID),
- Source: source.String(),
+ ID: route.ID,
+ Source: source,
Destination: destination,
Masquerade: route.Masquerade,
}, nil
diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go
index 88cca522a..98c34dbee 100644
--- a/client/internal/routemanager/static/route.go
+++ b/client/internal/routemanager/static/route.go
@@ -30,7 +30,7 @@ func (r *Route) String() string {
}
func (r *Route) AddRoute(context.Context) error {
- _, err := r.routeRefCounter.Increment(r.route.Network, nil)
+ _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{})
return err
}
diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go
index 43394a823..bb620ee68 100644
--- a/client/internal/routemanager/sysctl/sysctl_linux.go
+++ b/client/internal/routemanager/sysctl/sysctl_linux.go
@@ -13,7 +13,7 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
const (
@@ -23,7 +23,7 @@ const (
)
// Setup configures sysctl settings for RP filtering and source validation.
-func Setup(wgIface *iface.WGIface) (map[string]int, error) {
+func Setup(wgIface iface.IWGIface) (map[string]int, error) {
keys := map[string]int{}
var result *multierror.Error
diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go
index cddd7e7e2..d1cb83bfb 100644
--- a/client/internal/routemanager/systemops/systemops.go
+++ b/client/internal/routemanager/systemops/systemops.go
@@ -5,9 +5,9 @@ import (
"net/netip"
"sync"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
- "github.com/netbirdio/netbird/iface"
)
type Nexthop struct {
@@ -15,11 +15,11 @@ type Nexthop struct {
Intf *net.Interface
}
-type ExclusionCounter = refcounter.Counter[any, Nexthop]
+type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop]
type SysOps struct {
refCounter *ExclusionCounter
- wgInterface *iface.WGIface
+ wgInterface iface.IWGIface
// prefixes is tracking all the current added prefixes im memory
// (this is used in iOS as all route updates require a full table update)
//nolint
@@ -30,7 +30,7 @@ type SysOps struct {
notifier *notifier.Notifier
}
-func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps {
+func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps {
return &SysOps{
wgInterface: wgInterface,
notifier: notifier,
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 671545b86..9258f4a4e 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -16,10 +16,10 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
- "github.com/netbirdio/netbird/iface"
nbnet "github.com/netbirdio/netbird/util/net"
)
@@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn
}
refCounter := refcounter.New(
- func(prefix netip.Prefix, _ any) (Nexthop, error) {
+ func(prefix netip.Prefix, _ struct{}) (Nexthop, error) {
initialNexthop := initialNextHopV4
if prefix.Addr().Is6() {
initialNexthop = initialNextHopV6
@@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
-func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) {
+func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) {
addr := prefix.Addr()
switch {
case addr.IsLoopback(),
@@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("convert ip to prefix: %w", err)
}
- if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil {
+ if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil {
return fmt.Errorf("adding route reference: %v", err)
}
diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go
index 94965c119..238225807 100644
--- a/client/internal/routemanager/systemops/systemops_generic_test.go
+++ b/client/internal/routemanager/systemops/systemops_generic_test.go
@@ -19,7 +19,7 @@ import (
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
- "github.com/netbirdio/netbird/iface"
+ "github.com/netbirdio/netbird/client/iface"
)
type dialer interface {
diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go
index 0d3630cb8..3f756788e 100644
--- a/client/internal/routemanager/systemops/systemops_windows.go
+++ b/client/internal/routemanager/systemops/systemops_windows.go
@@ -3,6 +3,8 @@
package systemops
import (
+ "context"
+ "encoding/binary"
"fmt"
"net"
"net/netip"
@@ -11,15 +13,43 @@ import (
"strconv"
"strings"
"sync"
+ "syscall"
"time"
+ "unsafe"
log "github.com/sirupsen/logrus"
"github.com/yusufpapurcu/wmi"
+ "golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
nbnet "github.com/netbirdio/netbird/util/net"
)
+type RouteUpdateType int
+
+// RouteUpdate represents a change in the routing table.
+// The interface field contains the index only.
+type RouteUpdate struct {
+ Type RouteUpdateType
+ Destination netip.Prefix
+ NextHop netip.Addr
+ Interface *net.Interface
+}
+
+// RouteMonitor provides a way to monitor changes in the routing table.
+type RouteMonitor struct {
+ updates chan RouteUpdate
+ handle windows.Handle
+ done chan struct{}
+}
+
+// Route represents a single routing table entry.
+type Route struct {
+ Destination netip.Prefix
+ Nexthop netip.Addr
+ Interface *net.Interface
+}
+
type MSFT_NetRoute struct {
DestinationPrefix string
NextHop string
@@ -28,33 +58,77 @@ type MSFT_NetRoute struct {
AddressFamily uint16
}
-type Route struct {
- Destination netip.Prefix
- Nexthop netip.Addr
- Interface *net.Interface
+// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2
+type MIB_IPFORWARD_ROW2 struct {
+ InterfaceLuid uint64
+ InterfaceIndex uint32
+ DestinationPrefix IP_ADDRESS_PREFIX
+ NextHop SOCKADDR_INET_NEXTHOP
+ SitePrefixLength uint8
+ ValidLifetime uint32
+ PreferredLifetime uint32
+ Metric uint32
+ Protocol uint32
+ Loopback uint8
+ AutoconfigureAddress uint8
+ Publish uint8
+ Immortal uint8
+ Age uint32
+ Origin uint32
}
-type MSFT_NetNeighbor struct {
- IPAddress string
- LinkLayerAddress string
- State uint8
- AddressFamily uint16
- InterfaceIndex uint32
- InterfaceAlias string
+// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix
+type IP_ADDRESS_PREFIX struct {
+ Prefix SOCKADDR_INET
+ PrefixLength uint8
}
-type Neighbor struct {
- IPAddress netip.Addr
- LinkLayerAddress string
- State uint8
- AddressFamily uint16
- InterfaceIndex uint32
- InterfaceAlias string
+// SOCKADDR_INET is defined in https://learn.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-sockaddr_inet
+// It represents the union of IPv4 and IPv6 socket addresses
+type SOCKADDR_INET struct {
+ sin6_family int16
+ // nolint:unused
+ sin6_port uint16
+ // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id
+ data [24]byte
}
-var prefixList []netip.Prefix
-var lastUpdate time.Time
-var mux = sync.Mutex{}
+// SOCKADDR_INET_NEXTHOP is the same as SOCKADDR_INET but offset by 2 bytes
+type SOCKADDR_INET_NEXTHOP struct {
+ // nolint:unused
+ pad [2]byte
+ sin6_family int16
+ // nolint:unused
+ sin6_port uint16
+ // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id
+ data [24]byte
+}
+
+// MIB_NOTIFICATION_TYPE is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ne-netioapi-mib_notification_type
+type MIB_NOTIFICATION_TYPE int32
+
+var (
+ modiphlpapi = windows.NewLazyDLL("iphlpapi.dll")
+ procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2")
+ procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2")
+
+ prefixList []netip.Prefix
+ lastUpdate time.Time
+ mux sync.Mutex
+)
+
+const (
+ MibParemeterModification MIB_NOTIFICATION_TYPE = iota
+ MibAddInstance
+ MibDeleteInstance
+ MibInitialNotification
+)
+
+const (
+ RouteModified RouteUpdateType = iota
+ RouteAdded
+ RouteDeleted
+)
func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
return r.setupRefCounter(initAddresses)
@@ -94,6 +168,155 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro
return nil
}
+// NewRouteMonitor creates and starts a new RouteMonitor.
+// It returns a pointer to the RouteMonitor and an error if the monitor couldn't be started.
+func NewRouteMonitor(ctx context.Context) (*RouteMonitor, error) {
+ rm := &RouteMonitor{
+ updates: make(chan RouteUpdate, 5),
+ done: make(chan struct{}),
+ }
+
+ if err := rm.start(ctx); err != nil {
+ return nil, err
+ }
+
+ return rm, nil
+}
+
+func (rm *RouteMonitor) start(ctx context.Context) error {
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
+
+ callbackPtr := windows.NewCallback(func(callerContext uintptr, row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) uintptr {
+ if ctx.Err() != nil {
+ return 0
+ }
+
+ update, err := rm.parseUpdate(row, notificationType)
+ if err != nil {
+ log.Errorf("Failed to parse route update: %v", err)
+ return 0
+ }
+
+ select {
+ case <-rm.done:
+ return 0
+ case rm.updates <- update:
+ default:
+ log.Warn("Route update channel is full, dropping update")
+ }
+ return 0
+ })
+
+ var handle windows.Handle
+ if err := notifyRouteChange2(windows.AF_UNSPEC, callbackPtr, 0, false, &handle); err != nil {
+ return fmt.Errorf("NotifyRouteChange2 failed: %w", err)
+ }
+
+ rm.handle = handle
+
+ return nil
+}
+
+func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) (RouteUpdate, error) {
+ // destination prefix, next hop, interface index, interface luid are guaranteed to be there
+ // GetIpForwardEntry2 is not needed
+
+ var update RouteUpdate
+
+ idx := int(row.InterfaceIndex)
+ if idx != 0 {
+ intf, err := net.InterfaceByIndex(idx)
+ if err != nil {
+ return update, fmt.Errorf("get interface name: %w", err)
+ }
+
+ update.Interface = intf
+ }
+
+ log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface)
+ dest := parseIPPrefix(row.DestinationPrefix, idx)
+ if !dest.Addr().IsValid() {
+ return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row)
+ }
+
+ nexthop := parseIPNexthop(row.NextHop, idx)
+ if !nexthop.IsValid() {
+ return RouteUpdate{}, fmt.Errorf("invalid next hop %v", row)
+ }
+
+ updateType := RouteModified
+ switch notificationType {
+ case MibParemeterModification:
+ updateType = RouteModified
+ case MibAddInstance:
+ updateType = RouteAdded
+ case MibDeleteInstance:
+ updateType = RouteDeleted
+ }
+
+ update.Type = updateType
+ update.Destination = dest
+ update.NextHop = nexthop
+
+ return update, nil
+}
+
+// Stop stops the RouteMonitor.
+func (rm *RouteMonitor) Stop() error {
+ if rm.handle != 0 {
+ if err := cancelMibChangeNotify2(rm.handle); err != nil {
+ return fmt.Errorf("CancelMibChangeNotify2 failed: %w", err)
+ }
+ rm.handle = 0
+ }
+ close(rm.done)
+ close(rm.updates)
+ return nil
+}
+
+// RouteUpdates returns a channel that receives RouteUpdate messages.
+func (rm *RouteMonitor) RouteUpdates() <-chan RouteUpdate {
+ return rm.updates
+}
+
+func notifyRouteChange2(family uint32, callback uintptr, callerContext uintptr, initialNotification bool, handle *windows.Handle) error {
+ var initNotif uint32
+ if initialNotification {
+ initNotif = 1
+ }
+
+ r1, _, e1 := syscall.SyscallN(
+ procNotifyRouteChange2.Addr(),
+ uintptr(family),
+ callback,
+ callerContext,
+ uintptr(initNotif),
+ uintptr(unsafe.Pointer(handle)),
+ )
+ if r1 != 0 {
+ if e1 != 0 {
+ return e1
+ }
+ return syscall.EINVAL
+ }
+ return nil
+}
+
+func cancelMibChangeNotify2(handle windows.Handle) error {
+ r1, _, e1 := syscall.SyscallN(procCancelMibChangeNotify2.Addr(), uintptr(handle))
+ if r1 != 0 {
+ if e1 != 0 {
+ return e1
+ }
+ return syscall.EINVAL
+ }
+ return nil
+}
+
+// GetRoutesFromTable returns the current routing table with prefixes only.
+// It caches the result for 2 seconds to avoid blocking the caller.
func GetRoutesFromTable() ([]netip.Prefix, error) {
mux.Lock()
defer mux.Unlock()
@@ -117,6 +340,7 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return prefixList, nil
}
+// GetRoutes retrieves the current routing table using WMI.
func GetRoutes() ([]Route, error) {
var entries []MSFT_NetRoute
@@ -146,8 +370,8 @@ func GetRoutes() ([]Route, error) {
Name: entry.InterfaceAlias,
}
- if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) {
- nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex)))
+ if nexthop.Is6() {
+ nexthop = addZone(nexthop, int(entry.InterfaceIndex))
}
}
@@ -161,33 +385,6 @@ func GetRoutes() ([]Route, error) {
return routes, nil
}
-func GetNeighbors() ([]Neighbor, error) {
- var entries []MSFT_NetNeighbor
- query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor`
- if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil {
- return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err)
- }
-
- var neighbors []Neighbor
- for _, entry := range entries {
- addr, err := netip.ParseAddr(entry.IPAddress)
- if err != nil {
- log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err)
- continue
- }
- neighbors = append(neighbors, Neighbor{
- IPAddress: addr,
- LinkLayerAddress: entry.LinkLayerAddress,
- State: entry.State,
- AddressFamily: entry.AddressFamily,
- InterfaceIndex: entry.InterfaceIndex,
- InterfaceAlias: entry.InterfaceAlias,
- })
- }
-
- return neighbors, nil
-}
-
func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
args := []string{"add", prefix.String()}
@@ -220,3 +417,54 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error {
func isCacheDisabled() bool {
return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true"
}
+
+func parseIPPrefix(prefix IP_ADDRESS_PREFIX, idx int) netip.Prefix {
+ ip := parseIP(prefix.Prefix, idx)
+ return netip.PrefixFrom(ip, int(prefix.PrefixLength))
+}
+
+func parseIP(addr SOCKADDR_INET, idx int) netip.Addr {
+ return parseIPGeneric(addr.sin6_family, addr.data, idx)
+}
+
+func parseIPNexthop(addr SOCKADDR_INET_NEXTHOP, idx int) netip.Addr {
+ return parseIPGeneric(addr.sin6_family, addr.data, idx)
+}
+
+func parseIPGeneric(family int16, data [24]byte, interfaceIndex int) netip.Addr {
+ switch family {
+ case windows.AF_INET:
+ ipv4 := binary.BigEndian.Uint32(data[:4])
+ return netip.AddrFrom4([4]byte{
+ byte(ipv4 >> 24),
+ byte(ipv4 >> 16),
+ byte(ipv4 >> 8),
+ byte(ipv4),
+ })
+
+ case windows.AF_INET6:
+ // The IPv6 address is stored after the 4-byte flowinfo field
+ var ipv6 [16]byte
+ copy(ipv6[:], data[4:20])
+ ip := netip.AddrFrom16(ipv6)
+
+ // Check if there's a non-zero scope_id
+ scopeID := binary.BigEndian.Uint32(data[20:24])
+ if scopeID != 0 {
+ ip = ip.WithZone(strconv.FormatUint(uint64(scopeID), 10))
+ } else if interfaceIndex != 0 {
+ ip = addZone(ip, interfaceIndex)
+ }
+
+ return ip
+ }
+
+ return netip.IPv4Unspecified()
+}
+
+func addZone(ip netip.Addr, interfaceIndex int) netip.Addr {
+ if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ ip = ip.WithZone(strconv.Itoa(interfaceIndex))
+ }
+ return ip
+}
diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/ebpf/portlookup.go
similarity index 96%
rename from client/internal/wgproxy/portlookup.go
rename to client/internal/wgproxy/ebpf/portlookup.go
index 6f3d33487..0e2c20c99 100644
--- a/client/internal/wgproxy/portlookup.go
+++ b/client/internal/wgproxy/ebpf/portlookup.go
@@ -1,4 +1,4 @@
-package wgproxy
+package ebpf
import (
"fmt"
diff --git a/client/internal/wgproxy/portlookup_test.go b/client/internal/wgproxy/ebpf/portlookup_test.go
similarity index 97%
rename from client/internal/wgproxy/portlookup_test.go
rename to client/internal/wgproxy/ebpf/portlookup_test.go
index 6a386f330..92f4b8eee 100644
--- a/client/internal/wgproxy/portlookup_test.go
+++ b/client/internal/wgproxy/ebpf/portlookup_test.go
@@ -1,4 +1,4 @@
-package wgproxy
+package ebpf
import (
"fmt"
diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/ebpf/proxy.go
similarity index 64%
rename from client/internal/wgproxy/proxy_ebpf.go
rename to client/internal/wgproxy/ebpf/proxy.go
index bbd00d6e2..27ede3ef1 100644
--- a/client/internal/wgproxy/proxy_ebpf.go
+++ b/client/internal/wgproxy/ebpf/proxy.go
@@ -1,6 +1,6 @@
//go:build linux && !android
-package wgproxy
+package ebpf
import (
"context"
@@ -13,47 +13,49 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
+ "github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
+ nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/ebpf"
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
nbnet "github.com/netbirdio/netbird/util/net"
)
+const (
+ loopbackAddr = "127.0.0.1"
+)
+
// WGEBPFProxy definition for proxy with EBPF support
type WGEBPFProxy struct {
- ebpfManager ebpfMgr.Manager
-
- ctx context.Context
- cancel context.CancelFunc
-
- lastUsedPort uint16
localWGListenPort int
+ ebpfManager ebpfMgr.Manager
turnConnStore map[uint16]net.Conn
turnConnMutex sync.Mutex
- rawConn net.PacketConn
- conn transport.UDPConn
+ lastUsedPort uint16
+ rawConn net.PacketConn
+ conn transport.UDPConn
+
+ ctx context.Context
+ ctxCancel context.CancelFunc
}
// NewWGEBPFProxy create new WGEBPFProxy instance
-func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy {
+func NewWGEBPFProxy(wgPort int) *WGEBPFProxy {
log.Debugf("instantiate ebpf proxy")
wgProxy := &WGEBPFProxy{
localWGListenPort: wgPort,
ebpfManager: ebpf.GetEbpfManagerInstance(),
- lastUsedPort: 0,
turnConnStore: make(map[uint16]net.Conn),
}
- wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx)
-
return wgProxy
}
-// listen load ebpf program and listen the proxy
-func (p *WGEBPFProxy) listen() error {
+// Listen load ebpf program and listen the proxy
+func (p *WGEBPFProxy) Listen() error {
pl := portLookup{}
wgPorxyPort, err := pl.searchFreePort()
if err != nil {
@@ -72,13 +74,14 @@ func (p *WGEBPFProxy) listen() error {
addr := net.UDPAddr{
Port: wgPorxyPort,
- IP: net.ParseIP("127.0.0.1"),
+ IP: net.ParseIP(loopbackAddr),
}
+ p.ctx, p.ctxCancel = context.WithCancel(context.Background())
+
conn, err := nbnet.ListenUDP("udp", &addr)
if err != nil {
- cErr := p.Free()
- if cErr != nil {
+ if cErr := p.Free(); cErr != nil {
log.Errorf("Failed to close the wgproxy: %s", cErr)
}
return err
@@ -91,108 +94,114 @@ func (p *WGEBPFProxy) listen() error {
}
// AddTurnConn add new turn connection for the proxy
-func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
+func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) {
wgEndpointPort, err := p.storeTurnConn(turnConn)
if err != nil {
return nil, err
}
- go p.proxyToLocal(wgEndpointPort, turnConn)
+ go p.proxyToLocal(ctx, wgEndpointPort, turnConn)
log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort)
wgEndpoint := &net.UDPAddr{
- IP: net.ParseIP("127.0.0.1"),
+ IP: net.ParseIP(loopbackAddr),
Port: int(wgEndpointPort),
}
return wgEndpoint, nil
}
-// CloseConn doing nothing because this type of proxy implementation does not store the connection
-func (p *WGEBPFProxy) CloseConn() error {
- return nil
-}
-
-// Free resources
+// Free releases resources; the remote connections are kept open.
func (p *WGEBPFProxy) Free() error {
log.Debugf("free up ebpf wg proxy")
- var err1, err2, err3 error
- if p.conn != nil {
- err1 = p.conn.Close()
+ if p.ctx != nil && p.ctx.Err() != nil {
+ //nolint
+ return nil
}
- err2 = p.ebpfManager.FreeWGProxy()
- if p.rawConn != nil {
- err3 = p.rawConn.Close()
+ p.ctxCancel()
+
+ var result *multierror.Error
+ if p.conn != nil { // p.conn will be nil if we have failed to listen
+ if err := p.conn.Close(); err != nil {
+ result = multierror.Append(result, err)
+ }
}
- if err1 != nil {
- return err1
+ if err := p.ebpfManager.FreeWGProxy(); err != nil {
+ result = multierror.Append(result, err)
}
- if err2 != nil {
- return err2
+ if err := p.rawConn.Close(); err != nil {
+ result = multierror.Append(result, err)
}
-
- return err3
+ return nberrors.FormatErrorOrNil(result)
}
-func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) {
+func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) {
+ defer p.removeTurnConn(endpointPort)
+
+ var (
+ err error
+ n int
+ )
buf := make([]byte, 1500)
- var err error
- defer func() {
- p.removeTurnConn(endpointPort)
- }()
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- var n int
- n, err = remoteConn.Read(buf)
- if err != nil {
- if err != io.EOF {
- log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
- }
+ for ctx.Err() == nil {
+ n, err = remoteConn.Read(buf)
+ if err != nil {
+ if ctx.Err() != nil {
return
}
- err = p.sendPkg(buf[:n], endpointPort)
- if err != nil {
- log.Errorf("failed to write out turn pkg to local conn: %v", err)
+ if err != io.EOF {
+ log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err)
}
+ return
+ }
+
+ if err := p.sendPkg(buf[:n], endpointPort); err != nil {
+ if ctx.Err() != nil || p.ctx.Err() != nil {
+ return
+ }
+ log.Errorf("failed to write out turn pkg to local conn: %v", err)
}
}
}
// proxyToRemote read messages from local WireGuard interface and forward it to remote conn
+// Only one instance of this goroutine runs at a time.
func (p *WGEBPFProxy) proxyToRemote() {
buf := make([]byte, 1500)
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- n, addr, err := p.conn.ReadFromUDP(buf)
- if err != nil {
- log.Errorf("failed to read UDP pkg from WG: %s", err)
+ for p.ctx.Err() == nil {
+ if err := p.readAndForwardPacket(buf); err != nil {
+ if p.ctx.Err() != nil {
return
}
-
- p.turnConnMutex.Lock()
- conn, ok := p.turnConnStore[uint16(addr.Port)]
- p.turnConnMutex.Unlock()
- if !ok {
- log.Infof("turn conn not found by port: %d", addr.Port)
- continue
- }
-
- _, err = conn.Write(buf[:n])
- if err != nil {
- log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err)
- }
+ log.Errorf("failed to proxy packet to remote conn: %s", err)
}
}
}
+func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error {
+ n, addr, err := p.conn.ReadFromUDP(buf)
+ if err != nil {
+ return fmt.Errorf("failed to read UDP packet from WG: %w", err)
+ }
+
+ p.turnConnMutex.Lock()
+ conn, ok := p.turnConnStore[uint16(addr.Port)]
+ p.turnConnMutex.Unlock()
+ if !ok {
+ if p.ctx.Err() == nil {
+ log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port)
+ }
+ return nil
+ }
+
+ if _, err := conn.Write(buf[:n]); err != nil {
+ return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err)
+ }
+ return nil
+}
+
func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
@@ -206,11 +215,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) {
}
func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) {
- log.Tracef("remove turn conn from store by port: %d", turnConnID)
p.turnConnMutex.Lock()
defer p.turnConnMutex.Unlock()
- delete(p.turnConnStore, turnConnID)
+ _, ok := p.turnConnStore[turnConnID]
+ if ok {
+ log.Debugf("remove turn conn from store by port: %d", turnConnID)
+ }
+ delete(p.turnConnStore, turnConnID)
}
func (p *WGEBPFProxy) nextFreePort() (uint16, error) {
diff --git a/client/internal/wgproxy/proxy_ebpf_test.go b/client/internal/wgproxy/ebpf/proxy_test.go
similarity index 86%
rename from client/internal/wgproxy/proxy_ebpf_test.go
rename to client/internal/wgproxy/ebpf/proxy_test.go
index 821e64218..b15bc686c 100644
--- a/client/internal/wgproxy/proxy_ebpf_test.go
+++ b/client/internal/wgproxy/ebpf/proxy_test.go
@@ -1,14 +1,13 @@
//go:build linux && !android
-package wgproxy
+package ebpf
import (
- "context"
"testing"
)
func TestWGEBPFProxy_connStore(t *testing.T) {
- wgProxy := NewWGEBPFProxy(context.Background(), 1)
+ wgProxy := NewWGEBPFProxy(1)
p, _ := wgProxy.storeTurnConn(nil)
if p != 1 {
@@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
- wgProxy := NewWGEBPFProxy(context.Background(), 1)
+ wgProxy := NewWGEBPFProxy(1)
_, _ = wgProxy.storeTurnConn(nil)
wgProxy.lastUsedPort = 65535
@@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) {
}
func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) {
- wgProxy := NewWGEBPFProxy(context.Background(), 1)
+ wgProxy := NewWGEBPFProxy(1)
for i := 0; i < 65535; i++ {
_, _ = wgProxy.storeTurnConn(nil)
diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go
new file mode 100644
index 000000000..c5639f840
--- /dev/null
+++ b/client/internal/wgproxy/ebpf/wrapper.go
@@ -0,0 +1,44 @@
+//go:build linux && !android
+
+package ebpf
+
+import (
+ "context"
+ "fmt"
+ "net"
+)
+
+// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call
+type ProxyWrapper struct {
+ WgeBPFProxy *WGEBPFProxy
+
+ remoteConn net.Conn
+ cancel context.CancelFunc // with this cancel function, we stop the remoteToLocal goroutine
+}
+
+func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
+ ctxConn, cancel := context.WithCancel(ctx)
+ addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn)
+
+ if err != nil {
+ cancel()
+ return nil, fmt.Errorf("add turn conn: %w", err)
+ }
+ e.remoteConn = remoteConn
+ e.cancel = cancel
+ return addr, err
+}
+
+// CloseConn closes the remoteConn and automatically removes the conn instance from the map
+func (e *ProxyWrapper) CloseConn() error {
+ if e.cancel == nil {
+ return fmt.Errorf("proxy not started")
+ }
+
+ e.cancel()
+
+ if err := e.remoteConn.Close(); err != nil {
+ return fmt.Errorf("failed to close remote conn: %w", err)
+ }
+ return nil
+}
diff --git a/client/internal/wgproxy/factory.go b/client/internal/wgproxy/factory.go
deleted file mode 100644
index f4eb150b0..000000000
--- a/client/internal/wgproxy/factory.go
+++ /dev/null
@@ -1,22 +0,0 @@
-package wgproxy
-
-import "context"
-
-type Factory struct {
- wgPort int
- ebpfProxy Proxy
-}
-
-func (w *Factory) GetProxy(ctx context.Context) Proxy {
- if w.ebpfProxy != nil {
- return w.ebpfProxy
- }
- return NewWGUserSpaceProxy(ctx, w.wgPort)
-}
-
-func (w *Factory) Free() error {
- if w.ebpfProxy != nil {
- return w.ebpfProxy.Free()
- }
- return nil
-}
diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go
index d01ae7e74..369ba99db 100644
--- a/client/internal/wgproxy/factory_linux.go
+++ b/client/internal/wgproxy/factory_linux.go
@@ -3,20 +3,26 @@
package wgproxy
import (
- "context"
-
log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
+ "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
)
-func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
+type Factory struct {
+ wgPort int
+ ebpfProxy *ebpf.WGEBPFProxy
+}
+
+func NewFactory(userspace bool, wgPort int) *Factory {
f := &Factory{wgPort: wgPort}
if userspace {
return f
}
- ebpfProxy := NewWGEBPFProxy(ctx, wgPort)
- err := ebpfProxy.listen()
+ ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
+ err := ebpfProxy.Listen()
if err != nil {
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
@@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory {
f.ebpfProxy = ebpfProxy
return f
}
+
+func (w *Factory) GetProxy() Proxy {
+ if w.ebpfProxy != nil {
+ p := &ebpf.ProxyWrapper{
+ WgeBPFProxy: w.ebpfProxy,
+ }
+ return p
+ }
+ return usp.NewWGUserSpaceProxy(w.wgPort)
+}
+
+func (w *Factory) Free() error {
+ if w.ebpfProxy == nil {
+ return nil
+ }
+ return w.ebpfProxy.Free()
+}
diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go
index d1640c97d..f930b09b3 100644
--- a/client/internal/wgproxy/factory_nonlinux.go
+++ b/client/internal/wgproxy/factory_nonlinux.go
@@ -2,8 +2,20 @@
package wgproxy
-import "context"
+import "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
-func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory {
+type Factory struct {
+ wgPort int
+}
+
+func NewFactory(_ bool, wgPort int) *Factory {
return &Factory{wgPort: wgPort}
}
+
+func (w *Factory) GetProxy() Proxy {
+ return usp.NewWGUserSpaceProxy(w.wgPort)
+}
+
+func (w *Factory) Free() error {
+ return nil
+}
diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go
index b88df73a0..96fae8dd1 100644
--- a/client/internal/wgproxy/proxy.go
+++ b/client/internal/wgproxy/proxy.go
@@ -1,12 +1,12 @@
package wgproxy
import (
+ "context"
"net"
)
-// Proxy is a transfer layer between the Turn connection and the WireGuard
+// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
- AddTurnConn(turnConn net.Conn) (net.Addr, error)
+ AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error)
CloseConn() error
- Free() error
}
diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go
new file mode 100644
index 000000000..b09e6be55
--- /dev/null
+++ b/client/internal/wgproxy/proxy_test.go
@@ -0,0 +1,128 @@
+//go:build linux
+
+package wgproxy
+
+import (
+ "context"
+ "io"
+ "net"
+ "os"
+ "runtime"
+ "testing"
+ "time"
+
+ "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf"
+ "github.com/netbirdio/netbird/client/internal/wgproxy/usp"
+ "github.com/netbirdio/netbird/util"
+)
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("trace", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
+type mocConn struct {
+ closeChan chan struct{}
+ closed bool
+}
+
+func newMockConn() *mocConn {
+ return &mocConn{
+ closeChan: make(chan struct{}),
+ }
+}
+
+func (m *mocConn) Read(b []byte) (n int, err error) {
+ <-m.closeChan
+ return 0, io.EOF
+}
+
+func (m *mocConn) Write(b []byte) (n int, err error) {
+ <-m.closeChan
+ return 0, io.EOF
+}
+
+func (m *mocConn) Close() error {
+ if m.closed == true {
+ return nil
+ }
+
+ m.closed = true
+ close(m.closeChan)
+ return nil
+}
+
+func (m *mocConn) LocalAddr() net.Addr {
+ panic("implement me")
+}
+
+func (m *mocConn) RemoteAddr() net.Addr {
+ return &net.UDPAddr{
+ IP: net.ParseIP("172.16.254.1"),
+ }
+}
+
+func (m *mocConn) SetDeadline(t time.Time) error {
+ panic("implement me")
+}
+
+func (m *mocConn) SetReadDeadline(t time.Time) error {
+ panic("implement me")
+}
+
+func (m *mocConn) SetWriteDeadline(t time.Time) error {
+ panic("implement me")
+}
+
+func TestProxyCloseByRemoteConn(t *testing.T) {
+ ctx := context.Background()
+
+ tests := []struct {
+ name string
+ proxy Proxy
+ }{
+ {
+ name: "userspace proxy",
+ proxy: usp.NewWGUserSpaceProxy(51830),
+ },
+ }
+
+ if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" {
+ ebpfProxy := ebpf.NewWGEBPFProxy(51831)
+ if err := ebpfProxy.Listen(); err != nil {
+ t.Fatalf("failed to initialize ebpf proxy: %s", err)
+ }
+ defer func() {
+ if err := ebpfProxy.Free(); err != nil {
+ t.Errorf("failed to free ebpf proxy: %s", err)
+ }
+ }()
+ proxyWrapper := &ebpf.ProxyWrapper{
+ WgeBPFProxy: ebpfProxy,
+ }
+
+ tests = append(tests, struct {
+ name string
+ proxy Proxy
+ }{
+ name: "ebpf proxy",
+ proxy: proxyWrapper,
+ })
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ relayedConn := newMockConn()
+ _, err := tt.proxy.AddTurnConn(ctx, relayedConn)
+ if err != nil {
+ t.Errorf("error: %v", err)
+ }
+
+ _ = relayedConn.Close()
+ if err := tt.proxy.CloseConn(); err != nil {
+ t.Errorf("error: %v", err)
+ }
+ })
+ }
+}
diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go
deleted file mode 100644
index 234ea2a42..000000000
--- a/client/internal/wgproxy/proxy_userspace.go
+++ /dev/null
@@ -1,108 +0,0 @@
-package wgproxy
-
-import (
- "context"
- "fmt"
- "net"
-
- log "github.com/sirupsen/logrus"
-
- nbnet "github.com/netbirdio/netbird/util/net"
-)
-
-// WGUserSpaceProxy proxies
-type WGUserSpaceProxy struct {
- localWGListenPort int
- ctx context.Context
- cancel context.CancelFunc
-
- remoteConn net.Conn
- localConn net.Conn
-}
-
-// NewWGUserSpaceProxy instantiate a user space WireGuard proxy
-func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy {
- log.Debugf("Initializing new user space proxy with port %d", wgPort)
- p := &WGUserSpaceProxy{
- localWGListenPort: wgPort,
- }
- p.ctx, p.cancel = context.WithCancel(ctx)
- return p
-}
-
-// AddTurnConn start the proxy with the given remote conn
-func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) {
- p.remoteConn = turnConn
-
- var err error
- p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
- if err != nil {
- log.Errorf("failed dialing to local Wireguard port %s", err)
- return nil, err
- }
-
- go p.proxyToRemote()
- go p.proxyToLocal()
-
- return p.localConn.LocalAddr(), err
-}
-
-// CloseConn close the localConn
-func (p *WGUserSpaceProxy) CloseConn() error {
- p.cancel()
- if p.localConn == nil {
- return nil
- }
- return p.localConn.Close()
-}
-
-// Free doing nothing because this implementation of proxy does not have global state
-func (p *WGUserSpaceProxy) Free() error {
- return nil
-}
-
-// proxyToRemote proxies everything from Wireguard to the RemoteKey peer
-// blocks
-func (p *WGUserSpaceProxy) proxyToRemote() {
-
- buf := make([]byte, 1500)
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- n, err := p.localConn.Read(buf)
- if err != nil {
- continue
- }
-
- _, err = p.remoteConn.Write(buf[:n])
- if err != nil {
- continue
- }
- }
- }
-}
-
-// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard
-// blocks
-func (p *WGUserSpaceProxy) proxyToLocal() {
-
- buf := make([]byte, 1500)
- for {
- select {
- case <-p.ctx.Done():
- return
- default:
- n, err := p.remoteConn.Read(buf)
- if err != nil {
- continue
- }
-
- _, err = p.localConn.Write(buf[:n])
- if err != nil {
- continue
- }
- }
- }
-}
diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go
new file mode 100644
index 000000000..83a8725d8
--- /dev/null
+++ b/client/internal/wgproxy/usp/proxy.go
@@ -0,0 +1,146 @@
+package usp
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/client/errors"
+)
+
+// WGUserSpaceProxy proxies
+type WGUserSpaceProxy struct {
+ localWGListenPort int
+ ctx context.Context
+ cancel context.CancelFunc
+
+ remoteConn net.Conn
+ localConn net.Conn
+ closeMu sync.Mutex
+ closed bool
+}
+
+// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation
+func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
+ log.Debugf("Initializing new user space proxy with port %d", wgPort)
+ p := &WGUserSpaceProxy{
+ localWGListenPort: wgPort,
+ }
+ return p
+}
+
+// AddTurnConn start the proxy with the given remote conn
+func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) {
+ p.ctx, p.cancel = context.WithCancel(ctx)
+
+ p.remoteConn = remoteConn
+
+ var err error
+ dialer := net.Dialer{}
+ p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
+ if err != nil {
+ log.Errorf("failed dialing to local Wireguard port %s", err)
+ return nil, err
+ }
+
+ go p.proxyToRemote()
+ go p.proxyToLocal()
+
+ return p.localConn.LocalAddr(), err
+}
+
+// CloseConn close the localConn
+func (p *WGUserSpaceProxy) CloseConn() error {
+ if p.cancel == nil {
+ return fmt.Errorf("proxy not started")
+ }
+ return p.close()
+}
+
+func (p *WGUserSpaceProxy) close() error {
+ p.closeMu.Lock()
+ defer p.closeMu.Unlock()
+
+ // prevent double close
+ if p.closed {
+ return nil
+ }
+ p.closed = true
+
+ p.cancel()
+
+ var result *multierror.Error
+ if err := p.remoteConn.Close(); err != nil {
+ result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
+ }
+
+ if err := p.localConn.Close(); err != nil {
+ result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
+ }
+ return errors.FormatErrorOrNil(result)
+}
+
+// proxyToRemote proxies from Wireguard to the RemoteKey
+func (p *WGUserSpaceProxy) proxyToRemote() {
+ defer func() {
+ if err := p.close(); err != nil {
+ log.Warnf("error in proxy to remote loop: %s", err)
+ }
+ }()
+
+ buf := make([]byte, 1500)
+ for p.ctx.Err() == nil {
+ n, err := p.localConn.Read(buf)
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+ log.Debugf("failed to read from wg interface conn: %s", err)
+ return
+ }
+
+ _, err = p.remoteConn.Write(buf[:n])
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+
+ log.Debugf("failed to write to remote conn: %s", err)
+ return
+ }
+ }
+}
+
+// proxyToLocal proxies from the Remote peer to local WireGuard
+func (p *WGUserSpaceProxy) proxyToLocal() {
+ defer func() {
+ if err := p.close(); err != nil {
+ log.Warnf("error in proxy to local loop: %s", err)
+ }
+ }()
+
+ buf := make([]byte, 1500)
+ for p.ctx.Err() == nil {
+ n, err := p.remoteConn.Read(buf)
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+ log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
+ return
+ }
+
+ _, err = p.localConn.Write(buf[:n])
+ if err != nil {
+ if p.ctx.Err() != nil {
+ return
+ }
+ log.Debugf("failed to write to wg interface conn: %s", err)
+ continue
+ }
+ }
+}
diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go
index 779c27a4d..dc13706bf 100644
--- a/client/ios/NetBirdSDK/client.go
+++ b/client/ios/NetBirdSDK/client.go
@@ -168,7 +168,6 @@ func (c *Client) GetStatusDetails() *StatusDetails {
BytesTx: p.BytesTx,
ConnStatus: p.ConnStatus.String(),
ConnStatusUpdate: p.ConnStatusUpdate.Format("2006-01-02 15:04:05"),
- Direct: p.Direct,
LastWireguardHandshake: p.LastWireguardHandshake.String(),
Relayed: p.Relayed,
RosenpassEnabled: p.RosenpassEnabled,
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index fb10a38d3..b942d8b6e 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -1,7 +1,7 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
-// protoc v4.23.4
+// protoc v3.21.12
// source: daemon.proto
package proto
@@ -899,7 +899,6 @@ type PeerState struct {
ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"`
ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"`
Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"`
- Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"`
LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"`
RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"`
Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"`
@@ -911,6 +910,7 @@ type PeerState struct {
RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"`
Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"`
Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"`
+ RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"`
}
func (x *PeerState) Reset() {
@@ -980,13 +980,6 @@ func (x *PeerState) GetRelayed() bool {
return false
}
-func (x *PeerState) GetDirect() bool {
- if x != nil {
- return x.Direct
- }
- return false
-}
-
func (x *PeerState) GetLocalIceCandidateType() string {
if x != nil {
return x.LocalIceCandidateType
@@ -1064,6 +1057,13 @@ func (x *PeerState) GetLatency() *durationpb.Duration {
return nil
}
+func (x *PeerState) GetRelayAddress() string {
+ if x != nil {
+ return x.RelayAddress
+ }
+ return ""
+}
+
// LocalPeerState contains the latest state of the local peer
type LocalPeerState struct {
state protoimpl.MessageState
@@ -2243,7 +2243,7 @@ var file_daemon_proto_rawDesc = []byte{
0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e,
0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c,
0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50,
- 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65,
+ 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65,
0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65,
0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12,
@@ -2255,209 +2255,210 @@ var file_daemon_proto_rawDesc = []byte{
0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75,
0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79,
0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65,
- 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28,
- 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63,
- 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79,
- 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49,
+ 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e,
+ 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64,
+ 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74,
+ 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70,
+ 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49,
0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12,
- 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64,
- 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52,
- 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64,
- 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18,
- 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c,
- 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65,
- 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19,
- 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74,
- 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d,
- 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45,
- 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72,
- 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74,
- 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73,
- 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68,
- 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67,
- 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65,
- 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67,
- 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a,
- 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07,
- 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73,
- 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54,
- 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e,
- 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73,
- 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a,
- 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72,
- 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79,
- 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e,
- 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f,
- 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c,
- 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a,
- 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a,
- 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70,
- 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49,
- 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f,
- 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12,
- 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66,
- 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
- 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72,
- 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12,
- 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d,
- 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f,
- 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76,
- 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28,
- 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67,
- 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18,
- 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f,
- 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63,
- 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f,
- 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57,
- 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
- 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03,
- 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64,
- 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65,
- 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
- 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79,
- 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c,
- 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69,
- 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e,
- 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73,
- 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65,
- 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73,
- 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12,
- 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08,
- 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72,
- 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22,
- 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41,
- 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
- 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65,
- 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74,
- 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65,
- 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67,
- 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61,
- 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b,
- 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50,
- 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50,
- 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72,
- 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
- 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72,
- 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28,
- 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79,
- 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a,
- 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03,
- 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72,
- 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72,
- 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74,
- 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73,
- 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
- 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32,
- 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06,
- 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74,
- 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a,
- 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52,
- 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70,
- 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e,
- 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03,
- 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75,
- 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49,
- 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03,
- 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74,
- 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49,
- 0x44, 0x12, 0x18, 0x0a, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01,
- 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73,
- 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73,
- 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69,
- 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e,
- 0x73, 0x12, 0x40, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73,
- 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x52, 0x6f, 0x75, 0x74, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50,
- 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64,
- 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49,
- 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c,
- 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a,
- 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64,
- 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f,
- 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e,
- 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75,
- 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12,
- 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20,
- 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22,
- 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
- 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65,
- 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
- 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c,
- 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22,
- 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01,
- 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f,
- 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a,
- 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c,
- 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a,
- 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41,
- 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08,
- 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f,
- 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a,
- 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65,
- 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f,
- 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67,
- 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
- 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67,
- 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74,
- 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
- 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f,
- 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
- 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55,
- 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39,
- 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
- 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77,
- 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42,
- 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61,
- 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47,
- 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
- 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
- 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f,
- 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61,
- 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c,
- 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d,
- 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
- 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
- 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70,
- 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65,
- 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65,
- 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
- 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
- 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75,
- 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65,
- 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
- 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
- 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12,
- 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a,
- 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65,
- 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65,
+ 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66,
+ 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43,
+ 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74,
+ 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65,
+ 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e,
+ 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61,
+ 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18,
+ 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65,
+ 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e,
+ 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61,
+ 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28,
+ 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
+ 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c,
+ 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64,
+ 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78,
+ 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12,
+ 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03,
+ 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73,
+ 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20,
+ 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18,
+ 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a,
+ 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19,
+ 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
+ 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e,
+ 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65,
+ 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41,
+ 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c,
+ 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18,
+ 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62,
+ 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65,
+ 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72,
+ 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e,
+ 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66,
+ 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12,
+ 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62,
+ 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e,
+ 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72,
+ 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69,
+ 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70,
+ 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a,
+ 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72,
+ 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28,
+ 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65,
+ 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a,
+ 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12,
+ 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01,
+ 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a,
+ 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72,
+ 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74,
+ 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03,
+ 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65,
+ 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c,
+ 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f,
+ 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65,
+ 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
+ 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03,
+ 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65,
+ 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e,
+ 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04,
+ 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a,
+ 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20,
+ 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e,
+ 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61,
+ 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a,
+ 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01,
+ 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e,
+ 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65,
+ 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20,
+ 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65,
+ 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a,
+ 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e,
+ 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74,
+ 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73,
+ 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53,
+ 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73,
+ 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75,
+ 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72,
+ 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74,
+ 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74,
+ 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75,
+ 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75,
+ 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18,
+ 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a,
+ 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22,
+ 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52,
+ 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73,
+ 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03,
+ 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a,
+ 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a,
+ 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07,
+ 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63,
+ 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04,
+ 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a,
+ 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 0x20, 0x03,
+ 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74,
+ 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74,
+ 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a,
+ 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e,
+ 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02,
+ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50,
+ 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22,
+ 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69,
+ 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d,
+ 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20,
+ 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73,
+ 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52,
+ 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44,
+ 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
+ 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
+ 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67,
+ 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13,
+ 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01,
+ 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c,
+ 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53,
+ 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e,
+ 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76,
+ 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74,
+ 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07,
+ 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e,
+ 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12,
+ 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41,
+ 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09,
+ 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41,
+ 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53,
+ 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12,
+ 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65,
+ 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c,
+ 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b,
+ 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b,
+ 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c,
+ 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69,
+ 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55,
+ 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71,
+ 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70,
+ 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74,
+ 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74,
+ 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61,
+ 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e,
+ 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e,
+ 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65,
+ 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
+ 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f,
+ 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45,
+ 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73,
+ 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e,
+ 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52,
+ 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53,
+ 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65,
+ 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65,
+ 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65,
+ 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f,
+ 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
+ 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73,
+ 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63,
+ 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22,
+ 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65,
+ 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42,
+ 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64,
+ 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c,
+ 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47,
+ 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65,
0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52,
- 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74,
- 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f,
- 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71,
- 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
- 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73,
- 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70,
- 0x72, 0x6f, 0x74, 0x6f, 0x33,
+ 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e,
+ 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f,
+ 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c,
+ 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65,
+ 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
+ 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67,
+ 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42,
+ 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
+ 0x33,
}
var (
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 43c379fb5..384bc0e62 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -168,7 +168,6 @@ message PeerState {
string connStatus = 3;
google.protobuf.Timestamp connStatusUpdate = 4;
bool relayed = 5;
- bool direct = 6;
string localIceCandidateType = 7;
string remoteIceCandidateType = 8;
string fqdn = 9;
@@ -180,6 +179,7 @@ message PeerState {
bool rosenpassEnabled = 15;
repeated string routes = 16;
google.protobuf.Duration latency = 17;
+ string relayAddress = 18;
}
// LocalPeerState contains the latest state of the local peer
diff --git a/client/server/debug.go b/client/server/debug.go
index 1187f3187..5ed43293b 100644
--- a/client/server/debug.go
+++ b/client/server/debug.go
@@ -369,8 +369,8 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) {
}
for _, relay := range status.Relays {
- if relay.URI != nil {
- a.AnonymizeURI(relay.URI.String())
+ if relay.URI != "" {
+ a.AnonymizeURI(relay.URI)
}
}
}
diff --git a/client/server/server.go b/client/server/server.go
index 8173d0741..0a4c18131 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -12,7 +12,6 @@ import (
"github.com/cenkalti/backoff/v4"
"golang.org/x/exp/maps"
-
"google.golang.org/protobuf/types/known/durationpb"
log "github.com/sirupsen/logrus"
@@ -143,10 +142,12 @@ func (s *Server) Start() error {
s.sessionWatcher.SetOnExpireListener(s.onSessionExpire)
}
- if !config.DisableAutoConnect {
- go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+ if config.DisableAutoConnect {
+ return nil
}
+ go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
+
return nil
}
@@ -154,7 +155,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
- mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe,
+ runningChan chan error,
) {
backOff := getConnectWithBackoff(ctx)
retryStarted := false
@@ -185,7 +186,15 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf
runOperation := func() error {
log.Tracef("running client connection")
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
- err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe)
+
+ probes := internal.ProbeHolder{
+ MgmProbe: s.mgmProbe,
+ SignalProbe: s.signalProbe,
+ RelayProbe: s.relayProbe,
+ WgProbe: s.wgProbe,
+ }
+
+ err := s.connectClient.RunWithProbes(&probes, runningChan)
if err != nil {
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
}
@@ -576,9 +585,22 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
- go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+ runningChan := make(chan error)
+ go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
- return &proto.UpResponse{}, nil
+ for {
+ select {
+ case err := <-runningChan:
+ if err != nil {
+ log.Debugf("waiting for engine to become ready failed: %s", err)
+ } else {
+ return &proto.UpResponse{}, nil
+ }
+ case <-callerCtx.Done():
+ log.Debug("context done, stopping the wait for engine to become ready")
+ return nil, callerCtx.Err()
+ }
+ }
}
// Down engine work in the daemon.
@@ -590,28 +612,19 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
return nil, fmt.Errorf("service is not up")
}
s.actCancel()
+
+ err := s.connectClient.Stop()
+ if err != nil {
+ log.Errorf("failed to shut down properly: %v", err)
+ return nil, err
+ }
+
state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle)
- maxWaitTime := 5 * time.Second
- timeout := time.After(maxWaitTime)
+ log.Infof("service is down")
- engine := s.connectClient.Engine()
-
- for {
- if !engine.IsWGIfaceUp() {
- return &proto.DownResponse{}, nil
- }
-
- select {
- case <-ctx.Done():
- return &proto.DownResponse{}, nil
- case <-timeout:
- return nil, fmt.Errorf("failed to shut down properly")
- default:
- time.Sleep(100 * time.Millisecond)
- }
- }
+ return &proto.DownResponse{}, nil
}
// Status returns the daemon status
@@ -745,11 +758,11 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
ConnStatus: peerState.ConnStatus.String(),
ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate),
Relayed: peerState.Relayed,
- Direct: peerState.Direct,
LocalIceCandidateType: peerState.LocalIceCandidateType,
RemoteIceCandidateType: peerState.RemoteIceCandidateType,
LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint,
RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint,
+ RelayAddress: peerState.RelayServerAddress,
Fqdn: peerState.FQDN,
LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake),
BytesRx: peerState.BytesRx,
@@ -763,7 +776,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus {
for _, relayState := range fullStatus.Relays {
pbRelayState := &proto.RelayState{
- URI: relayState.URI.String(),
+ URI: relayState.URI,
Available: relayState.Err == nil,
}
if err := relayState.Err; err != nil {
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 6a3de774c..9b18df4d3 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -6,10 +6,11 @@ import (
"testing"
"time"
- "github.com/netbirdio/management-integrations/integrations"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
+ "github.com/netbirdio/management-integrations/integrations"
+
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
@@ -73,7 +74,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
- s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe)
+ s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
@@ -129,8 +130,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
- turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
if err != nil {
return nil, "", err
}
@@ -158,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
log.Fatalf("failed to listen: %v", err)
}
- srv, err := signalServer.NewServer(otel.Meter(""))
+ srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
diff --git a/client/testdata/management.json b/client/testdata/management.json
index 4745f2e8c..674c66e06 100644
--- a/client/testdata/management.json
+++ b/client/testdata/management.json
@@ -20,6 +20,13 @@
"Secret": "c29tZV9wYXNzd29yZA==",
"TimeBasedCredentials": true
},
+ "Relay": {
+ "Addresses": [
+ "localhost:0"
+ ],
+ "CredentialsTTL": "1h",
+ "Secret": "b29tZV9wYXNzd29yZA=="
+ },
"Signal": {
"Proto": "http",
"URI": "signal.wiretrustee.com:10000",
@@ -34,4 +41,4 @@
"AuthAudience": "",
"AuthKeysLocation": ""
}
-}
\ No newline at end of file
+}
diff --git a/client/ui/bundled.go b/client/ui/bundled.go
new file mode 100644
index 000000000..e2c138b14
--- /dev/null
+++ b/client/ui/bundled.go
@@ -0,0 +1,12 @@
+// auto-generated
+// Code generated by '$ fyne bundle'. DO NOT EDIT.
+
+package main
+
+import "fyne.io/fyne/v2"
+
+var resourceNetbirdSystemtrayConnectedPng = &fyne.StaticResource{
+ StaticName: "netbird-systemtray-connected.png",
+ StaticContent: []byte(
+ "\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\b\x06\x00\x00\x00\\r\xa8f\x00\x00\x00\xc3zTXtRaw profile type exif\x00\x00x\xdamP\xdb\r\xc3 \f\xfc\xf7\x14\x1d\xc1\xaf\x80\x19\x874\xa9\xd4\r:~\r8Q\x88r\x92χ\x9d\x1cư\xff\xbe\x1fx50)\xe8\x92-\x95\x94СE\vW\x17\x86\x03\xb53\xa1v>@\xc1S\x1dNɞų\x8c\x86\xa5\xf8\xeb\xa8\xd3d\x83T]-\x17#{Gc\x9d\x1bEGf\xbb\x19\xc5E\xd2&b\x17[\x18\x950\x12\x1e\r\n\x83:\x9e\x85\xa9X\xbe>a\xddq\x86\x8d\x80F\x92\xbb\xf7ir?k\xf6\xedm\x8b\x17\x85y\x17\x12t\x16\xd11\x80\xb4P\x90\xdaE\xf5\xf0\xa1\xfc#u-\x92:[L\xe2\vy\xda\xd3\x01\xf8\x03\xda\xd4Y\x17ݮ\xb7\xee\x00\x00\x01\x84iCCPICC profile\x00\x00x\x9c}\x91=H\xc3@\x1c\xc5_S\xa5\"-\x0e\x16\x14\x11\xccP\x9d\xec\xa2\"\xe2T\xabP\x84\n\xa5Vh\xd5\xc1\xe4\xd2/hҐ\xa4\xb88\n\xae\x05\a?\x16\xab\x0e.κ:\xb8\n\x82\xe0\a\x88\xb3\x83\x93\xa2\x8b\x94\xf8\xbf\xa4\xd0\"ƃ\xe3~\xbc\xbb\xf7\xb8{\a\b\x8d\nSͮ\x18\xa0j\x96\x91N\xc4\xc5lnU\f\xbcB\xc0\x00B\x18\xc1\xac\xc4L}.\x95J\xc2s|\xdd\xc3\xc7\u05fb(\xcf\xf2>\xf7\xe7\b)y\x93\x01>\x918\xc6t\xc3\"\xde \x9e\u07b4t\xce\xfb\xc4aV\x92\x14\xe2s\xe2q\x83.H\xfc\xc8u\xd9\xe57\xceE\x87\x05\x9e\x1962\xe9y\xe20\xb1X\xec`\xb9\x83Y\xc9P\x89\xa7\x88#\x8a\xaaQ\xbe\x90uY\xe1\xbc\xc5Y\xad\xd4X\xeb\x9e\xfc\x85\xc1\xbc\xb6\xb2\xccu\x9a\xc3H`\x11KHA\x84\x8c\x1aʨ\xc0B\x94V\x8d\x14\x13iڏ{\xf8\x87\x1c\x7f\x8a\\2\xb9\xca`\xe4X@\x15*$\xc7\x0f\xfe\a\xbf\xbb5\v\x93\x13nR0\x0et\xbf\xd8\xf6\xc7(\x10\xd8\x05\x9au\xdb\xfe>\xb6\xed\xe6\t\xe0\x7f\x06\xae\xb4\xb6\xbf\xda\x00f>I\xaf\xb7\xb5\xc8\x11з\r\\\\\xb75y\x0f\xb8\xdc\x01\x06\x9ftɐ\x1c\xc9OS(\x14\x80\xf73\xfa\xa6\x1c\xd0\x7f\v\xf4\xae\xb9\xbd\xb5\xf6q\xfa\x00d\xa8\xab\xe4\rpp\b\x8c\x15){\xdd\xe3\xdd=\x9d\xbd\xfd{\xa6\xd5\xdf\x0fںr\xd0VwQ\xba\x00\x00\rxiTXtXML:com.adobe.xmp\x00\x00\x00\x00\x00\n\n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n\xf0C\xff\xd9\x00\x00\x00\x06bKGD\x00\xff\x00\xff\x00\xff\xa0\xbd\xa7\x93\x00\x00\x00\tpHYs\x00\x00\v\x13\x00\x00\v\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\atIME\a\xe8\x02\x17\r$'\xdd\xf7ȗ\x00\x00\x13;IDATx\xda\xed\x9d]o\x14W\x9a\xc7\xff\xa7\xaamh\xbf\xc46,I`\x99\xa1\xc3\ni\xb5{1\x95O0\xe4\x1b\xc0'X\xf2\t`.W`hp\xa2\xb9\fH{O\xa3\xcc\xc5\xecJ3q\xa4\x1d\xed\xcdJx>Aj/\"EBJګL \xb1\x00g\xf1\v\xb6\xbb\xeb\xec\x85mb\f\xb6\xfb\xa5^Ω\xfa\xfd\xee\x928v\xf7\xa9z\xfe\xcfs\x9e\xa7ο$\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\u0603a\t\xc0g\xd6\x7f\x1f5\x92\x8e\"k\xd4\b\xa4s\xb2jH\x9afez\n\xfe\xdb\b\x00x\x81mF\xd3/CE]\xa3(\x94~c\xa5\x8b;\xc1\x0e\x83E\x7f{\xecF\xfcA\x8d\x95\x00g\xb3\xfb\\tQ\xd2o\xadtq]\xba(I\x81\x95,K345\xa3˒\x84\x00\x80SY~5Х0\xd0o\x13\xabK\x96R>\x9b\xe4o\xd4\x1a\xbd\x1e\xc7\b\x008\x93\xe9\xadtkM\x8a\x02i\xdaZ\x9aS\x99\x12\xea\xf6\xabJ\x80Հ\x02\xf7\xf4W\x13\xe9\xdan\xa6'\xe8sXw\xe9\xf6ؿ\xc6\xed_Z\x01\x00\x05d{\xed\xec\xe9!\xcf\xda\x7f\xbb\xf1\xf7Z/\x80U\x81<\x03\xdf\x12\xf8\xc5\xc5\x7f\xf2K\xe9O\x05\x00d\xfcje\xffx\xecF\xfc\xe1\xfe\x7fM\x05\x00\xd9\x04~3j$5ݲVWX\r\a\xe2?\xdc\x1e\xfb!\x00\x909ks\xd1\xd5Dj\x1a\xcb\x18ω\xe07j\xd5\xf74\xfe\x10\x00Ȅ\x95O\xa3(Ht_R\xc4\xdeҙҿ\xbdw췟\x80\x15\x82\xb4\xb2~\x90\xe8+I\x11\xab\xe1\x0e\xd6\xea\xc1A\xd9\x7f[\x1f\x00\x86\xdc\xeb\xdbP\xf7E\x93\xcf\xc9\xec\xbf\x7f\xec\xc7\x16\x00\xd2\v\xfe\xb9\xe8b\"}axd\xd7\xcd\xf8O\x0e.\xfd\xd9\x02\xc0\xb0\xc1\x7f\xcbJ\x0f\t~G\t4_\xbf\x19\xb7\x8e\xfa1*\x00\xe8\x9b\xd5O\xa2\xfb\x8c\xf7\x1c\xcf\xfe\x81~\xd7\xcb\xcf!\x00\xd03\xb6\x19M\xaf\x87\xfaB\x96\xfd\xbe\xd3\xc1\x7f\xc8؏-\x00\fV\xf27\xa3\xc6z\xa8\x87\xa2\xd9\xe7x\xf4\x1f>\xf6\xa3\x02\x80\x81\x82\xdf\xd6\xf4\x10\a\x1e\x0f\xe2?\xd1\xed\xfa\x8d\u07b2\xff\xb6^\x00\x10\xfc\xa5\xc9\xfeG\x8d\xfdJW\x01\xd8f4\xfd\xf2ؾNt\xe7\xe8\x9b
5\f\xb4\xdc\r\xb4\xbc\xfb\xcf\xc77\xb4l\x9a\xf12wѾ=?\xc1\xef\xcf\xf5\xb2\xbd5\xfe\x9c\xac\x00\xd6\x7f\x1f5\xc2D\xd3[\x89\x1a\xd6jZ\x81\xa6\x8d\xd5t`tn\xe7\xcb5$M\xcb\xec\x04{\x867\xa5\x95\x96\x8dѲ\xacvK\xa9ec\xb4\x9cX-Z\xa3ec\xd5\x0e\xa4e\xd5\xd4\xee\xb5\xd9\xe2#ks\x11O\xf6\xf9\x92\xfc\x8dZ\xf5\x1b\xf1\xc7N\n\x80mF\xd3[#jlv\x15)\xd0\xf4\x1e\xfb憌\xa6\xbd\xcf0F\xed\x1d\xb1X\xb6\xd2\xffX\xabvh\xd4>\xdeU\xeckU\xb1v'\xfaLF\xd7\b-On\xc1\x9a>\x18$\x19\x99,\x82<\b\xf4\x1b\xb3\xed\xed\x16Y\xa9Q\xe5\x87E\xac\xb4l\xa4XFq\"-\x86V\xb1\xeb°\xf3\x90O\x93\xb0\xf2\xe6\x1e\xbb=>\x1b\x0ft\xbd\xcc\xc0\x81nuq'\x93Gv\xfb\xf4\x17O\x84\xf5G,\xa9\x9d\x18\xfd5\xb4\x8a\xeb\xb3\xf1\x82\v\x1fju.\xbad\xa4/\xb8<\xfeT\x9f\xf5\x8e>\x1c4\xa1\x98\xa3\x82}-\xd4Ek\xd4\xe0e\f\xb9\xb0 \xa3\xd8Z\xfdu\xac\xab\x85\xbc\xab\x04:\xfe\x1eƿ\xd5ǽ<\xf2ۓ\x00l~\x1aE\x9bV\x17\tv\x87\xaa\x04\xa3\x05c\xf5e\x1e\x15\xc2\xda'\xd1w\\s\xbf\xb2\x7f\xbfc\xbf7~\xc5\xda\\tU\xd2%\xcax/z\t\v\x89\u0557\xe1\x88\x16Ҟ>\xb0\xef\xf70\xfe\al\xfc\xbd\xf6;V>\x8d\"\x93p\xaa\xcb\xc7\xedBb\xf5 \r1\xd89\xd3\xff\x1dK\xeaQ\xf0\x0f8\xf6{\xeb\x16`ǹ\xf5!\xcbZM1X\x9d\x8b\xbe2\xcc\xfb\xbd\xaa\x06\x83\x9a>L\xa3\n|\xd5\x03X\xbf\x13]\xb1F\xf7Y^\uf677҃\xf1\xd9x\xbe\x97\x1f^\xb9\x13]\t\xb8\xee\xbe\t\xc0\xc0c\xbf\x03\x05\x00\x11([\x8d\xa8\xb6\x12͛\x11\xdd;,S\xd0\xf8\xf3\xef\xba\x0e\xdb\xf8\xdb\xcbkǁ\xeb7\xe3\x96U\xefG\t\xc1\xe94ѐ\xd15\xdb\xd1w\xab\x9fD\xf7w^\xb5\xfd\xfa\xde\x7f.\xbaE\xf0{\x16\xffI\xba\xf1i\x0e\xd8\x136\xcd\xf6\xdb\\\xa0d\xd9#It{\xe2fܢ\xf1\xe7\xe5\xf5[\x18\xbb\x11\x7f\x94\xb9\x00 \x02\x15\x10\x82\xe7\x13\xed`z\xe5\"\x8b\xe1\xd1eKa\xec׳\x00 \x02%\xde\x1dlִ\xf5\xf5\xafeF;\nO?Wp\xe2\x05\x8b\xe2z\xf0\xa74\xf6;\xb4\a\xb0\x9f\xf1ٸ\x99H\x0fX\xfer\xd1yt\xe6\x95\x10t\x16Oi\xeb\xeb_+y6\xc9\xc28\\\xb1\xf5c\xf3\x95\x9a\x00H\xd2\xc4l|\x05\x11(\x0fɳI\xd9\xcd\xda\x1b\x15\xc1\xae\x10ؕ:\x8b\xe4\x1e\xf7\xb2\xf2\x9d\xe8\xf94\xe0\xfa\\\xf4\x90w\xbb{^\xfaw\x03u\xbe9\xfb\x86\x00\xbc\x91\x15N\xac(<\xfd\\ft\x8bEs 
\xfb\xa79\xf6\xeb\xbb\x02\xd8\xe5eW\x97\xed\xf6\x11V\xf05\xfb\xff4ud\xf0oW\t\x13\xda\xfa\xfaW\xea>\x9eaъ\x8e\xff$۱|_~\x00ϛ\xd1\xf4h\xa8\x87<6\xeaa\xf6\xdfi\xfc\xf5}\x83\x8cvT\xbb\xf0\x98j\xa0\x88\xe0Ϩ\xf17P\x05 I3\xcdx9\xe8게\xda\\\x1e\xbf\x184\x9bo\vǯ\xd4\xfd\xfe\xa4l\x97\xd7H\xe4J\x98\xfdCy}_\xd1z3n\x9b\x8e>B\x04<\xca\xfe+\xf5\xa1\xbb\xfcݥ\xa9\x9d\xfe\xc1\b\v\x9a\xc75\x93n\xe7a8;\x90\xa4#\x02~\xd1Y<\x95\xe26\x82\xde@\xf6\xb5\xbf\xdaAM\xad<\xfe\xd4\xc05\x1d\"\xe0\ao\x1b\xfb\r\xbd\x9dx2\xa3-\xaa\x81\xec\xe2?\xc9'\xfbok͐`(\xe2p\x19\xb9YS\xe7љ\xd4\x05\xe0\xd5\xcd3\xdaQx\xf6\xa9\x82\xa9U\x16;\xc5\xec\x9f\xe5\xd8/\xb5\n`\x97\x89\xebql\x03}d%ު\xe3\x18\xdd\xc73\x99\x05\xff+\x81\xf9\xf6=\xb6\x04)R3\xba\x9c\xe7\xdfK\xa5\xad;q=\x8e%}\xcc\xe5s+\xfb\xe7\xf5xo\xf7Ɍ:\x8b\xef2%\x186\xf9\x1b\xb5F\xb7c\xc9/\x01\x90\xa4\xf1\xd9x\xdeXD\xc0\xa5\xec\x9fo\xafa\x82)\xc1\xb0\x84\xf9{q\xa4*\xd9\xf5\x9bqK\xa6\xff\x17\x14B\xda\xc18Y\xc8\xe1\x9e\xed\x9e\xc3iD`\x90\xb5S~\x8d\xbf\xd7[\x0e\x19\xc01\xe2b\xd9\xfa\xfaי\xee\xfd\x8f\xced\x89F.<\x96\xa9op1z\x8b\xc2\\\x1b\x7f\x99U\x00{\xb6\x03M\xac\xc5\n*\xfd{|\xde?\xdb\x0f\x11h\xeb\xd1i%?\x8fsAz\x89\xff\xa4\xb8X\xc9\xf4\xed\xc0T\x02E\x94\xe0g\x8a\x17\x80\xbd\xc5\xc0\xb9%\x85\x18\x8e\x1c\x16\x81\xf1؍\xf8â\xfe|\xa6m\xdb\x1dC\x91{\\\xe5\x9c\x12o\xc6c\xbf\x81>\xd3\xe2)u1\x1b98\xfe\xc3|\xc7~\xb9\n\x80$M\xcc\xc6\xd70\x14\xc9'\xfb\xbb\xea\xea\x83\b\x1c\x10\xfcF\xad\"\x1a\x7f\xb9\n\xc0\x8e\b\xe0*\x941\x9do\xdfw\xbb:A\x04\xf6\x97\xfe\xed\"\xc6~\x85\b\x80$muu\rC\x91lH\x9eMʮ\x8f\xba\xbfE\xf9\xfe\xa4\xec\xfa1.\x98$k\xf5\xa0\xe8쿭C9\x82\xa1HF\xe2Z\xf4د\x1f\xc2D#\xff\xf8\xb7j\x1b\x8c\x148\xf6+\xac\x02\x90\xb6\rE6\xbb\x9c 
L5\xab:\xd8\xf8;\xfc\x03\a\x95\x7fX\xa8ȱ_\xa1\x02\xb0+\x02\x1c#N\xa9\x8cܬ\xa9\xfbd\xc6\xcb\xcf\xdd\xf9\xf6\xbdJ\x9e\x1d0F\xad\xfa\u0378UY\x01\x90\xf0\x12H3\xfb{+^\xeb\xa3\xea~\xffwջh\xa1[\x0f\xc8\x15&\xc1\x88\xc0\xb0\x01t\xcc\xfb\x97y$\xcf&*u\x94\u0605\xb1\x9f3\x02\xb0W\x04\xf0\x12\xe8\x9fη\uf563\x8ay2S\x8d\x97\x9182\xf6sJ\x00vE\x00C\x91~3\xe7\xa4_\x8d\xbf#\xd8\xfa\xf6\xbd\xd27\x05\xf3\xb4\xf9\xeaO\x97\x1c\x01k\xb1\x1eK\x7f\a\x9f\xf7O\xe5F\x9cx\xa9\x91\v?\x946\xfb\xbb2\xf6s\xae\x02\xd8e\xe2z\x1c\a\x16/\x81#\xb3\xff\xd3\xc9\xd2\x05\xbf$ٕ\xe3\xea.M\x953\xfe\x1d6\xcaqj\x0eS\xbf\x19\xb7p\x15:<\xfb\xfb8\xf6\xeb\xb9\x1f\xf0\xfd\xc9\xd2\xf5\x03\x8cQ\xab>\x1b/ \x00}\x88\x00\xaeB\a\x04H\x05:\xe6\x9d\xc5S\xe5z> t\xdb\x17\xc3ɕ\x1e\xbb\x11\xdf\xc5Pd_\xf6\xffy\xdc\xfb\xb1_\xafUNR\x12\xa1+\xca\xe6\xcb{\x01\x90p\x15z#3~\x7f\xb2:\x95\xceҔ\xff[\x01\xa3\xf6XWw]\xff\x98N\xd7Z\x88\xc06e\x1b\xfbUA\xf0L\xa2ۦ\x19;?\xda6>,\xe6\xca\\t7\x90\xaeV\xb2\xf4/\xe9د\xa7\xed\xf3٧\nO\xfd\xecg\xf6wt\xec\xe7U\x05\xb0K\x95]\x85\xbc;\xed\x97\xf6w\xf7\xb0!hB}\xe4\xcbg\xf5fu\xab\xe8*\xe4\xb2\xcdW>\n\x10xw`\xc8\xc5\xe7\xfdK!\x00R\xf5\\\x85*yZn\x1fɳ\to\x1a\x82VZv}\xec\xe7\xb5\x00\xec\x1a\x8aTA\x04\x92g\x93J~\x1e\x13H\x1d\x7fƂ\xf7|\xca\xfe\x92'M\xc0\xfd\xac7\xa3\x86\xad顬\x1ae\xbd齲\xf9ʁ\x91\v\x8fe&\xd6]\x8e$o\x1a\x7f\xdeV\x00\xbb\x94\xddK\xa0ʍ?_\xab\x00\x97l\xbeJ/\x00e\x16\x01\xbbYSR\xd2C1C\xad\xcb\xcaqw{\x01\x81\xe6]\xb2\xf9\xaa\x84\x00\x94U\x04|\x1d}U\xb9\n0\x81\xbfgW\xbc\xbf\xd3\xca\xe4*T\xf9\xb1\x9f\x87U\x80oc\xbf\xd2\t\xc0\xae\b\x94\xc1U\xa8\xf3\xe8\fQ\xeeS\x15\xe0\xa8\xcdW\xe5\x04@\xda1\x14Q\xb1/Z\x1c\x86*>\xef\xef{\x15\xe0\xaa\xcdW%\x05@\x92\xea\xb3\U000423c6\"\xb6\x1bT\xca\x1d\xb7\x14U\x80Q\xdb\xd7\xc6_i\x05@\xf2\xd3U(\xf9i\x8a\xec\xdfo\x15P\xb0\x89\xa8-\x89}])\xdb\xcd\xf5\x9bq˗c\xc4e\xb7\xf9\xcan\xcb4Q\\\xf27j\x8d\xcf\xc6\xf3\b\x80\xc3\xf8\xe2%@\xe9?\xe0\xba\xfd4Uܸ4,\x8fGE\xa9\aή\x8b\x80]\xa93\xf6\x1bX\x01\x82B\xd6\xce\a\x9b/\x04\xc0\x13\x11\xe8,\x9e\"\x90\x87\xd9\x06,\x8f\xe7\\\xfb\xab\x1d\xd4\xd4*\x
d3\x1aV⑳\xf1ٸ隗\x00c\xbf4*\xa8\xe3\xb2\xeb\xc7\xf2\x8b\xff\xa4\\ٿ2\x02 \xb9e(b7k\xec\xfd\xd3\x12Ҽ\x8eL\x97d\xecWY\x01\xd8\x15\x01#-\x14~\xd32\xf6K\xaf\x15\xf0S>\a\xa7j\xc6߇\xcc\x10\x80=\xbc\xec\xear\x91\x86\"v\xb3V\xdaW`\x15\xa3\x00A\xe6O\x06\x1a\xa3\xd6\xe8\xf58F\x00J@ѮB\x94\xfe\x19\xaci\xd6ۀ\xb0\xbc\xd6\xf4\x95@X\xbd\xec\x8f\x00d \x02\x8c\xfd\x1c\xe8\x03\xf4)\x02I\xc5\x1a\x7f\b\xc0\x80\"Ћ\xa1\bo\xf7q\xa1\x0f\xd0\xc76\xc0\xa8\x1d\xd6t\xb7\xaak\x85\x00\xf4\xc1Q\xaeB\xd8|\xb9\xb2\r\xe8\xdd&\xac\x8c6_\xfd`\xb8]\xfagu.\xfa\xcaH\xd1k7]7P盳\b\x80\v7\xf5hG#\xff\xfc\xbf=e\xff\xb1\x1b\xf1\aU^+*\x80\x01x\x9b\xa1\b6_\x0eU\x00\x9b\xb5\x9e\x9a\xb0>\xbeF\x0e\x01p\x80]W\xa1\xdd\x13\x84\xbc\xdd\xc7A\x8e\xd8\x06\x18\xa3V}6^@\x00``\x11\xd8=F\xcc\x13\x7f\xee\x91\x1c5\t\xa8\xe8\xd8\x0f\x01H\x91z3n'\x1b\xb5{\xae\xbc\xae\x1a\xf6l\x03\x0e\xa9\x00\xcan\xf3\x85\x00乀\xc7:Wk\x17~\x90\u0084\xc5pI\x00\xd6\x0e\xa8\x00\x8c\xdac\xdd\xea\x8e\xfd\x10\x80\x14Y\x9b\x8b\xaeʪaF;\x1a\xb9\xf0\x18\x11pI\x00\x0ehȚD\xb7M3^f\x85\x10\x80\xa1XoF\r\x19]{uc\xd574r\xfeG\x16\xc6\x15\xba\xc1\x9b\x0eA%}\xbb\x0f\x02P\x00IM\xb7d\xd5x\xed\xfe\x9aXWxn\x89\xc5q\xa6\x0f\xf0\xfa6\xc0\x84\xfa\x88UA\x00R\xc9\xfe\xc6\xea\xca\xdb\xfe[x\xe2\x05\"\xe0\xe06\xa0J6_\b@\xd6\xd9?\xd4\x17\x87\xfd\xf7\xf0\xc4\v\x85\xa7\x9f\xb3P\x8e\b\x80\x95\x96\x19\xfb!\x00\xa9\xb0r'\xba\xb2\xff1්\xc0\xfb\xcf\x11\x01w*\x80{d\x7f\x04 \x9d\x05\vt\xabןE\x04\nf\xedX%\xde\xee\x83\x00\xe4\xb5\xf7\x9f\x8b\xdeh\xfc!\x02\x0eW\x00ݠ\x926_\xfd\xc0i\xc0^\x83\xbf\x195l\xa8\xef\x06\xfd\xff;\x8b\xa70\n\xc9\xfb\xe6~g\xbd=\xd5\xfa\xaf\x0fX\t*\x80\xa1Ij\xbd\x97\xfeo\xa3vnI\xc1\x89\x17,d\x8e\x84'V\xc8\xfeT\x00\xc5g\xff\xbdl=:û\x02\xf2\xc8lc\x9b\xf3\xef\xfc\xe1?/\xb3\x12T\x00\xc3\xef%kz\x98\xd6\xef\x1a9\xffD\xa6\xbeɢf\x99\xd5F;\xeav&~\xc7J 
\x00C\xb3r'\xba\xd2o\xe3\xef\xf0\xba4\xd1ȅ\x1f\x10\x81,K\xff\xa9\xf5\xd6\xcc\x1f\xff\xd8f%\x10\x80\xa1K\xff~\xc6~\xfd\x88@\xed\xfc\x13\x99\xd1\x0e\x8b\x9c>\xed\xb0\xb1\xc4\xde\x1f\x01H#P\xf5/\xa9f\xff}ej\xed\xc2\x0f\x88@\xda7\xf4\xf4\x1ag\xfd\xfb\xb9\x0fY\x82\x83\xb3\x7fZ\x8d\xbfC\xfb\v\x9b5u\x1e\x9d\xc1O0\x8d\x9b\x99\xb1\x1f\x15@Z\f;\xf6\xa3\x12(\xe0f\x9e\xd8\xfa\x98U@\x00\x86fu.\xbat\xd0i\xbf\xccD\xe0\xfc\x8f\x18\x8a\fs#\x8fo\xb4&\xff\xed\xbf\x17X\t\x04`\xf8\x804\xfa,\xf7\xbfY\xdf\xc0Uh\b\x01\x1d9\xb9J\xe3\x0f\x01\x18\x9e\xd4\xc7~}\x8a@\r/\x81\xfe\x19\xedܮ\xdf]h\xb3\x10\b\xc0Pd6\xf6\xeb\xe7\x82L\xadb(\xd2\x1f\xedw\xce\xff\x80\xc9'\x020\xff\v\x8d?*\x80\xe1qe\xec\xd7\x0f\xb5\xb3O+\xed%\x10\x9e\xfa?J\x7f*\x80\x94\xb2\xbfcc\xbf\x9e\xe9\x06\xdb\xd6b\xeb\xa3\xd5\xcaV#\xc9\xfc;\xff>\x8f\xcd\x17\x15@\n\xd9\xff\x88\xb7\xfb\xb8\x9d\x06w\\\x85*t\x82Ќv\xd45DZ\xf9B\x00\x86\xa7\u05f7\xfb\xb8.\x02U:F\x8c\xcd\x17\x02\x90ޗv|\xec\xd7OV\xac\x88\b`\xf3\x85\x00\xa4\xb4\xf7\xf7d\xec\x87\b\xfc\x82\x95\xb0\xf9\xca\xea\xfe\xa9T\xf0\xfb\xdc\xf8;*H6k\xda\xfa\xe6\xac\xd4-\x97\xa6\xf3\xbc?\x15@j\xe4e\xf3UT%PFC\x11l\xbe\xa8\x00Ra\xe5\xd3(\n\x12}U\xf6\xefi\u05cfi\xeb\xd1\xe9RT\x02\xc1\xf8F\xeb\x9d\xcf\xff\x82\x00P\x01\xa4\xf0E\xad\xc7c\xbf~\x14\xbd\xbeQ\n/\x013\xdaQwk\x92\xc6\x1f\x02\x90B\xf6/\xd0\xe6\xab\b\xc2\x13/\xfcw\x15\x1a\xed\xdcf\xec\x87\x00\f\x8d\v6_\x85\x89\x80\xbf^\x02\xd8|!\x00)\xed\x89\x03]\xadR\xf6\x7fM\x04<5\x14\t\xde_\xc6\xe6+\xaf\xadVٳ\x7fY\xc7~\xfd\xe0\x93\xb5\x18c?*\x80\xd4(\xf3د\xac\x95\x006_\b@*\xac܉\xae\xe4\xf9v\x1f/D\xc0qW\xa1`|\x83\xe7\xfd\x11\x80\x94\xbeX@\xf6\x7fC\x04\xdcv\x15\xc2\xe6\v\x01H\x87\xb5\xb9\xa8\xb2\x8d\xbf\xa3\xa8\x9d[\x92\x99x\xe9\xde\xde\x1f\x9b\xafbֽl_\xc8G\x9b\xaf\xdcq\xccP\xc4Ԓ\xf6\xd4\x7f\xcc\xd3\xf8\xa3\x02\x18\x1e\x1fm\xbe\xf2\xdf\v\xec\x18\x8a8b-\x16\x9e}J\xe9O\x05\x90R\xf6g\xec\xd73v\xb3\xa6Σ3\xb2\x9bŽ\x1e\x02\x9b/*\x80\x14S\x89\xeesI\xfbP\xff\x82\xbd\x04\xb0\xf9B\x00Rc\xe5Nt\xc5J\x17\xb9\xa4\xfe\x88\x80\x19\xe92\xf6C\x00R\xfa\"\x8c\xfd|\x13\x01\
xc6~\b@J{\xff\x92\xd9|\x15&\x02\xe7\x7f\xcc\xcdP\x04\x9b/G\xae\xbb\xf7\xc1ߌ\x1aI\xa8\xaf\x8c4\xcd\xe5L!0s0\x14\xe1\xed>T\x00\xa9\x91\xd4t\x8b\xe0O18\xeb\x1b\x1a9\xffc\xb67\xdd\xd4\x06.?T\x00\xe9d\x7f\xc6~\xd9\xd0}6\xa9\xee\xe2\xa9\xf4o8cZS\x7f\xfa\x13\x02@\x05\x90B\xb9Z\xd3C.a6d\xe1*dF;JFFh\xfc!\x00\xc3S5\x9b\xaf\xc2D \xcdc\xc4\xd8|!\x00\xa9d\xfef4\xcd\xd8/'\x11H\xcfK\x00\x9b/\x04 \x1d^\x86ⴟg\"\x10\x1c\xdf\xc2\xe6\xcbA\xbck\x02\xd2\xf8+\x8eA\xadŰ\xf9\xa2\x02H\rl\xbe\x8a\xad\x04\x061\x14\xc1\xe6\v\x01H\x85չ\xe8\x126_\xc5R;\xb7ԗ\b`\xf3\x85\x00\xa4\xb7_1\xfa\x8cK\xe6\x86\b\xf4\xe4%\x10&\xcb#'W\x19\xfb!\x00\xc3\xc3\xd8\xcf-z1\x141\xf5\xcd{\xf5\xbb\vd\x7f\x97\x93\xaa\x0f\x1f\x12\x9b/G\xe9\x06\xda\xfa\xe6\xecA\x86\"\xed\xe9?\xff\x99\xc6\x1f\x15\xc0\xf0`\xf3\xe5(ar\xe01\xe2Zc\x89ҟ\n \xa5\xec\xcf\xd8\xcfi\xf6[\x8ba\xf3E\x05\x90\xde\xcd\x15\xd2\xf8s>\x8b\xec1\x141a\x82͗G\xd4\\\xfep\xabs\xd1%\x19E\x92\xda\\*\xc7E\xe0XG#\xff\xf0X\x1b\x8b\xa7\xbe\x9c\xf9\xc3<\xd7\v\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\xe4\xff\x01\xf6P(\xf3)+S\x1f\x00\x00\x00\x00IEND\xaeB`\x82"),
+}
diff --git a/encryption/cert.go b/encryption/cert.go
new file mode 100644
index 000000000..3f6d5c679
--- /dev/null
+++ b/encryption/cert.go
@@ -0,0 +1,19 @@
+package encryption
+
+import "crypto/tls"
+
+func LoadTLSConfig(certFile, keyFile string) (*tls.Config, error) {
+ serverCert, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ return nil, err
+ }
+
+ config := &tls.Config{
+ Certificates: []tls.Certificate{serverCert},
+ ClientAuth: tls.NoClientCert,
+ NextProtos: []string{
+ "h2", "http/1.1", // enable HTTP/2
+ },
+ }
+ return config, nil
+}
diff --git a/encryption/letsencrypt.go b/encryption/letsencrypt.go
index cfe54ec5a..27a5e3110 100644
--- a/encryption/letsencrypt.go
+++ b/encryption/letsencrypt.go
@@ -9,7 +9,7 @@ import (
)
// CreateCertManager wraps common logic of generating Let's encrypt certificate.
-func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Manager, error) {
+func CreateCertManager(datadir string, letsencryptDomain ...string) (*autocert.Manager, error) {
certDir := filepath.Join(datadir, "letsencrypt")
if _, err := os.Stat(certDir); os.IsNotExist(err) {
@@ -24,7 +24,7 @@ func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Mana
certManager := &autocert.Manager{
Prompt: autocert.AcceptTOS,
Cache: autocert.DirCache(certDir),
- HostPolicy: autocert.HostWhitelist(letsencryptDomain),
+ HostPolicy: autocert.HostWhitelist(letsencryptDomain...),
}
return certManager, nil
diff --git a/encryption/route53.go b/encryption/route53.go
new file mode 100644
index 000000000..3c81ab103
--- /dev/null
+++ b/encryption/route53.go
@@ -0,0 +1,87 @@
+package encryption
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "os"
+ "strings"
+
+ "github.com/caddyserver/certmagic"
+ "github.com/libdns/route53"
+ log "github.com/sirupsen/logrus"
+ "go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
+ "golang.org/x/crypto/acme"
+)
+
+// Route53TLS by default, loads the AWS configuration from the environment.
+// env variables: AWS_REGION, AWS_PROFILE, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN
+type Route53TLS struct {
+ DataDir string
+ Email string
+ Domains []string
+ CA string
+}
+
+func (r *Route53TLS) GetCertificate() (*tls.Config, error) {
+ if len(r.Domains) == 0 {
+ return nil, fmt.Errorf("no domains provided")
+ }
+
+ certmagic.Default.Logger = logger()
+ certmagic.Default.Storage = &certmagic.FileStorage{Path: r.DataDir}
+ certmagic.DefaultACME.Agreed = true
+ if r.Email != "" {
+ certmagic.DefaultACME.Email = r.Email
+ } else {
+ certmagic.DefaultACME.Email = emailFromDomain(r.Domains[0])
+ }
+
+ if r.CA == "" {
+ certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA
+ } else {
+ certmagic.DefaultACME.CA = r.CA
+ }
+
+ certmagic.DefaultACME.DNS01Solver = &certmagic.DNS01Solver{
+ DNSManager: certmagic.DNSManager{
+ DNSProvider: &route53.Provider{},
+ },
+ }
+ cm := certmagic.NewDefault()
+ if err := cm.ManageSync(context.Background(), r.Domains); err != nil {
+ log.Errorf("failed to manage certificate: %v", err)
+ return nil, err
+ }
+
+ tlsConfig := &tls.Config{
+ GetCertificate: cm.GetCertificate,
+ NextProtos: []string{"h2", "http/1.1", acme.ALPNProto},
+ }
+
+ return tlsConfig, nil
+}
+
+func emailFromDomain(domain string) string {
+ if domain == "" {
+ return ""
+ }
+
+ parts := strings.Split(domain, ".")
+ if len(parts) < 2 {
+ return ""
+ }
+ if parts[0] == "" {
+ return ""
+ }
+ return fmt.Sprintf("admin@%s.%s", parts[len(parts)-2], parts[len(parts)-1])
+}
+
+func logger() *zap.Logger {
+ return zap.New(zapcore.NewCore(
+ zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()),
+ os.Stderr,
+ zap.ErrorLevel,
+ ))
+}
diff --git a/encryption/route53_test.go b/encryption/route53_test.go
new file mode 100644
index 000000000..765b60f84
--- /dev/null
+++ b/encryption/route53_test.go
@@ -0,0 +1,84 @@
+package encryption
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+)
+
+func TestRoute53TLSConfig(t *testing.T) {
+ t.SkipNow() // This test requires AWS credentials
+ exampleString := "Hello, world!"
+ rtls := &Route53TLS{
+ DataDir: t.TempDir(),
+ Email: os.Getenv("LE_EMAIL_ROUTE53"),
+ Domains: []string{os.Getenv("DOMAIN")},
+ }
+ tlsConfig, err := rtls.GetCertificate()
+ if err != nil {
+ t.Errorf("Route53TLSConfig failed: %v", err)
+ }
+
+ server := &http.Server{
+ Addr: ":8443",
+ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = w.Write([]byte(exampleString))
+ }),
+ TLSConfig: tlsConfig,
+ }
+
+ go func() {
+ err := server.ListenAndServeTLS("", "")
+ if err != http.ErrServerClosed {
+ t.Errorf("Failed to start server: %v", err)
+ }
+ }()
+ defer func() {
+ if err := server.Shutdown(context.Background()); err != nil {
+ t.Errorf("Failed to shutdown server: %v", err)
+ }
+ }()
+
+ time.Sleep(1 * time.Second)
+ resp, err := http.Get("https://relay.godevltd.com:8443")
+ if err != nil {
+ t.Errorf("Failed to get response: %v", err)
+ return
+ }
+ defer func() {
+ _ = resp.Body.Close()
+ }()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ t.Errorf("Failed to read response body: %v", err)
+ }
+ if string(body) != exampleString {
+ t.Errorf("Unexpected response: %s", body)
+ }
+}
+
+func Test_emailFromDomain(t *testing.T) {
+ tests := []struct {
+ input string
+ want string
+ }{
+ {"example.com", "admin@example.com"},
+ {"x.example.com", "admin@example.com"},
+ {"x.x.example.com", "admin@example.com"},
+ {"*.example.com", "admin@example.com"},
+ {"example", ""},
+ {"", ""},
+ {".com", ""},
+ }
+ for _, tt := range tests {
+ t.Run("domain test", func(t *testing.T) {
+ if got := emailFromDomain(tt.input); got != tt.want {
+ t.Errorf("emailFromDomain() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/go.mod b/go.mod
index e80a7eb46..d9f6162e5 100644
--- a/go.mod
+++ b/go.mod
@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird
-go 1.21.0
+go 1.23.0
require (
cunicu.li/go-rosenpass v0.4.0
@@ -10,9 +10,9 @@ require (
github.com/golang/protobuf v1.5.4
github.com/google/uuid v1.6.0
github.com/gorilla/mux v1.8.0
- github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7
+ github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83
github.com/onsi/ginkgo v1.16.5
- github.com/onsi/gomega v1.23.0
+ github.com/onsi/gomega v1.27.6
github.com/pion/ice/v3 v3.0.2
github.com/rs/cors v1.8.0
github.com/sirupsen/logrus v1.9.3
@@ -34,6 +34,7 @@ require (
fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/c-robinson/iplib v1.0.3
+ github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/coreos/go-iptables v0.7.0
github.com/creack/pty v1.1.18
@@ -50,18 +51,21 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
+ github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1
github.com/magiconair/properties v1.8.7
github.com/mattn/go-sqlite3 v1.14.19
github.com/mdlayher/socket v0.4.1
- github.com/miekg/dns v1.1.43
+ github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0
- github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e
+ github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd
+ github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pion/logging v0.2.2
+ github.com/pion/randutil v0.1.0
github.com/pion/stun/v2 v2.0.0
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
@@ -70,6 +74,7 @@ require (
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
+ github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8
github.com/stretchr/testify v1.9.0
github.com/testcontainers/testcontainers-go v0.31.0
github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0
@@ -81,6 +86,7 @@ require (
go.opentelemetry.io/otel/exporters/prometheus v0.48.0
go.opentelemetry.io/otel/metric v1.26.0
go.opentelemetry.io/otel/sdk/metric v1.26.0
+ go.uber.org/zap v1.27.0
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
@@ -93,6 +99,7 @@ require (
gorm.io/driver/postgres v1.5.7
gorm.io/driver/sqlite v1.5.3
gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde
+ nhooyr.io/websocket v1.8.11
)
require (
@@ -106,8 +113,23 @@ require (
github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
+ github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect
+ github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect
+ github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect
+ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect
+ github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect
+ github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect
+ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect
+ github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect
+ github.com/aws/smithy-go v1.20.3 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
+ github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.16 // indirect
github.com/containerd/log v0.1.0 // indirect
@@ -116,7 +138,7 @@ require (
github.com/dgraph-io/ristretto v0.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/distribution/reference v0.6.0 // indirect
- github.com/docker/docker v26.1.4+incompatible // indirect
+ github.com/docker/docker v26.1.5+incompatible // indirect
github.com/docker/go-connections v0.5.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
@@ -140,7 +162,7 @@ require (
github.com/googleapis/gax-go/v2 v2.12.3 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
- github.com/hashicorp/go-uuid v1.0.2 // indirect
+ github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
@@ -149,13 +171,17 @@ require (
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
+ github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect
github.com/kelseyhightower/envconfig v1.4.0 // indirect
github.com/klauspost/compress v1.17.8 // indirect
+ github.com/klauspost/cpuid/v2 v2.2.7 // indirect
+ github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
+ github.com/mholt/acmez/v2 v2.0.1 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
@@ -164,12 +190,12 @@ require (
github.com/morikuni/aec v1.0.0 // indirect
github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect
github.com/nxadm/tail v1.4.8 // indirect
+ github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pegasus-kv/thrift v0.13.0 // indirect
github.com/pion/dtls/v2 v2.2.10 // indirect
github.com/pion/mdns v0.0.12 // indirect
- github.com/pion/randutil v0.1.0 // indirect
github.com/pion/transport/v2 v2.2.4 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -186,10 +212,12 @@ require (
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
+ github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.26.0 // indirect
go.opentelemetry.io/otel/trace v1.26.0 // indirect
+ go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/text v0.16.0 // indirect
@@ -205,7 +233,7 @@ require (
k8s.io/apimachinery v0.26.2 // indirect
)
-replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0
+replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
diff --git a/go.sum b/go.sum
index 6d0408013..d823f505a 100644
--- a/go.sum
+++ b/go.sum
@@ -79,6 +79,34 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
+github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY=
+github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc=
+github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90=
+github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI=
+github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU=
+github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII=
+github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU=
+github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA=
+github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM=
+github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw=
+github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE=
+github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ=
+github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE=
+github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs=
@@ -87,6 +115,10 @@ github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d h1:pVrfxiGfwel
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA=
github.com/c-robinson/iplib v1.0.3 h1:NG0UF0GoEsrC1/vyfX1Lx2Ss7CySWl3KqqXh3q4DdPU=
github.com/c-robinson/iplib v1.0.3/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo=
+github.com/caddyserver/certmagic v0.21.3 h1:pqRRry3yuB4CWBVq9+cUqu+Y6E2z8TswbhNx1AZeYm0=
+github.com/caddyserver/certmagic v0.21.3/go.mod h1:Zq6pklO9nVRl3DIFUw9gVUfXKdpc/0qwTUAQMBlfgtI=
+github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA=
+github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4=
github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
@@ -132,8 +164,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
-github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU=
-github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
+github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g=
+github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
@@ -207,6 +239,8 @@ github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZs
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
+github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
+github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/go-text/render v0.1.0 h1:osrmVDZNHuP1RSu3pNG7Z77Sd2xSbcb/xWytAj9kyVs=
github.com/go-text/render v0.1.0/go.mod h1:jqEuNMenrmj6QRnkdpeaP0oKGFLDNhDkVKwGjsWWYU4=
github.com/go-text/typesetting v0.1.0 h1:vioSaLPYcHwPEPLT7gsjCGDCoYSbljxoHJzMnKwVvHw=
@@ -350,8 +384,9 @@ github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerX
github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4=
github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
-github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
+github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
+github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90=
@@ -382,6 +417,10 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
+github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
@@ -401,6 +440,9 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU=
github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
+github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
+github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
+github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
@@ -413,6 +455,10 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
+github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s=
+github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ=
+github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA=
+github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q=
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI=
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
@@ -431,9 +477,11 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
+github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
+github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
-github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg=
-github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4=
+github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
+github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
@@ -473,10 +521,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI=
-github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
-github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g=
-github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c=
+github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y=
+github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
+github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg=
+github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs=
github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA=
github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM=
@@ -492,14 +542,14 @@ github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
-github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs=
-github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo=
+github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
+github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA=
github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
-github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys=
-github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg=
+github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
+github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
@@ -592,6 +642,8 @@ github.com/smartystreets/assertions v1.13.0/go.mod h1:wDmR7qL282YbGsPy6H/yAsesrx
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg=
github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM=
+github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8=
+github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E=
github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I=
github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w=
@@ -660,6 +712,12 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc=
github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30=
+github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY=
+github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0=
+github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
+github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ=
+github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
+github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g=
go.etcd.io/etcd/client/v2 v2.305.0/go.mod h1:h9puh54ZTgAKtEbut2oe9P4L/oqKCVB6xsXlzd7alYQ=
@@ -695,8 +753,14 @@ go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZu
go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I=
go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
+go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
+go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
+go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
+go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo=
+go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
+go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4=
goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U=
golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
@@ -890,7 +954,6 @@ golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -1187,6 +1250,8 @@ k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8
k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk=
k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I=
k8s.io/kube-openapi v0.0.0-20191107075043-30be4d16710a/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E=
+nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
+nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8=
rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0=
rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA=
diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go
deleted file mode 100644
index 15e4a7817..000000000
--- a/iface/iface_darwin.go
+++ /dev/null
@@ -1,38 +0,0 @@
-//go:build !ios
-
-package iface
-
-import (
- "fmt"
-
- "github.com/pion/transport/v3"
-
- "github.com/netbirdio/netbird/iface/bind"
- "github.com/netbirdio/netbird/iface/netstack"
-)
-
-// NewWGIFace Creates a new WireGuard interface instance
-func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
- wgAddress, err := parseWGAddress(address)
- if err != nil {
- return nil, err
- }
-
- wgIFace := &WGIface{
- userspaceBind: true,
- }
-
- if netstack.IsEnabled() {
- wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
- return wgIFace, nil
- }
-
- wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
-
- return wgIFace, nil
-}
-
-// CreateOnAndroid this function make sense on mobile only
-func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
- return fmt.Errorf("this function has not implemented on this platform")
-}
diff --git a/iface/tun.go b/iface/tun.go
deleted file mode 100644
index b3c0f9d80..000000000
--- a/iface/tun.go
+++ /dev/null
@@ -1,18 +0,0 @@
-//go:build !android
-// +build !android
-
-package iface
-
-import (
- "github.com/netbirdio/netbird/iface/bind"
-)
-
-type wgTunDevice interface {
- Create() (wgConfigurer, error)
- Up() (*bind.UniversalUDPMuxDefault, error)
- UpdateAddr(address WGAddress) error
- WgAddress() WGAddress
- DeviceName() string
- Close() error
- Wrapper() *DeviceWrapper // todo eliminate this function
-}
diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go
deleted file mode 100644
index dd38ba075..000000000
--- a/iface/wg_configurer.go
+++ /dev/null
@@ -1,21 +0,0 @@
-package iface
-
-import (
- "errors"
- "net"
- "time"
-
- "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
-)
-
-var ErrPeerNotFound = errors.New("peer not found")
-
-type wgConfigurer interface {
- configureInterface(privateKey string, port int) error
- updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
- removePeer(peerKey string) error
- addAllowedIP(peerKey string, allowedIP string) error
- removeAllowedIP(peerKey string, allowedIP string) error
- close()
- getStats(peerKey string) (WGStats, error)
-}
diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env
index 296e165f0..45dce8d88 100644
--- a/infrastructure_files/base.setup.env
+++ b/infrastructure_files/base.setup.env
@@ -20,6 +20,12 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false}
NETBIRD_SIGNAL_PROTOCOL="http"
NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000}
+# Relay
+NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN}
+NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080}
+# Relay auth secret
+NETBIRD_RELAY_AUTH_SECRET=
+
# Turn
TURN_DOMAIN=${NETBIRD_TURN_DOMAIN:-$NETBIRD_DOMAIN}
@@ -69,7 +75,7 @@ NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-"latest"}
NETBIRD_SIGNAL_TAG=${NETBIRD_SIGNAL_TAG:-"latest"}
NETBIRD_MANAGEMENT_TAG=${NETBIRD_MANAGEMENT_TAG:-"latest"}
COTURN_TAG=${COTURN_TAG:-"latest"}
-
+NETBIRD_RELAY_TAG=${NETBIRD_RELAY_TAG:-"latest"}
# exports
export NETBIRD_DOMAIN
@@ -123,3 +129,7 @@ export NETBIRD_SIGNAL_TAG
export NETBIRD_MANAGEMENT_TAG
export COTURN_TAG
export NETBIRD_TURN_EXTERNAL_IP
+export NETBIRD_RELAY_DOMAIN
+export NETBIRD_RELAY_PORT
+export NETBIRD_RELAY_AUTH_SECRET
+export NETBIRD_RELAY_TAG
diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh
index f04735de6..ff33004b2 100755
--- a/infrastructure_files/configure.sh
+++ b/infrastructure_files/configure.sh
@@ -41,6 +41,18 @@ if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then
exit 1
fi
+# Check if PostgreSQL is set as the store engine
+if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then
+ # Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set
+ if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then
+ echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set."
+ echo "Please add the following line to your setup.env file:"
+ echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host= user= password= dbname= port="'
+ exit 1
+ fi
+ export NETBIRD_STORE_ENGINE_POSTGRES_DSN
+fi
+
# local development or tests
if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then
export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted"
@@ -77,6 +89,11 @@ fi
export TURN_EXTERNAL_IP_CONFIG
+# if not provided, we generate a relay auth secret
+if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then
+ export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g')
+fi
+
artifacts_path="./artifacts"
mkdir -p $artifacts_path
diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl
index 6b6831493..ba68b3f8d 100644
--- a/infrastructure_files/docker-compose.yml.tmpl
+++ b/infrastructure_files/docker-compose.yml.tmpl
@@ -49,6 +49,23 @@ services:
options:
max-size: "500m"
max-file: "2"
+ # Relay
+ relay:
+ image: netbirdio/relay:$NETBIRD_RELAY_TAG
+ restart: unless-stopped
+ environment:
+ - NB_LOG_LEVEL=info
+ - NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT
+ - NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT
+ # todo: change to a secure secret
+ - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET
+ ports:
+ - $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT
+ logging:
+ driver: "json-file"
+ options:
+ max-size: "500m"
+ max-file: "2"
# Management
management:
@@ -77,6 +94,9 @@ services:
options:
max-size: "500m"
max-file: "2"
+ environment:
+ - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
+
# Coturn
coturn:
image: coturn/coturn:$COTURN_TAG
diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik
index d3ae6529a..c4415d848 100644
--- a/infrastructure_files/docker-compose.yml.tmpl.traefik
+++ b/infrastructure_files/docker-compose.yml.tmpl.traefik
@@ -81,7 +81,9 @@ services:
- traefik.http.routers.netbird-management.service=netbird-management
- traefik.http.services.netbird-management.loadbalancer.server.port=443
- traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c
-
+ environment:
+ - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN
+
# Coturn
coturn:
image: coturn/coturn:$COTURN_TAG
diff --git a/infrastructure_files/download-geolite2.sh b/infrastructure_files/download-geolite2.sh
deleted file mode 100755
index 4a9db5e01..000000000
--- a/infrastructure_files/download-geolite2.sh
+++ /dev/null
@@ -1,109 +0,0 @@
-#!/bin/bash
-
-# to install sha256sum on mac: brew install coreutils
-if ! command -v sha256sum &> /dev/null
-then
- echo "sha256sum is not installed or not in PATH, please install with your package manager. e.g. sudo apt install sha256sum" > /dev/stderr
- exit 1
-fi
-
-if ! command -v sqlite3 &> /dev/null
-then
- echo "sqlite3 is not installed or not in PATH, please install with your package manager. e.g. sudo apt install sqlite3" > /dev/stderr
- exit 1
-fi
-
-if ! command -v unzip &> /dev/null
-then
- echo "unzip is not installed or not in PATH, please install with your package manager. e.g. sudo apt install unzip" > /dev/stderr
- exit 1
-fi
-
-download_geolite_mmdb() {
- DATABASE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
- SIGNATURE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
- # Download the database and signature files
- echo "Downloading mmdb signature file..."
- SIGNATURE_FILE=$(curl -s -L -O -J "$SIGNATURE_URL" -w "%{filename_effective}")
- echo "Downloading mmdb database file..."
- DATABASE_FILE=$(curl -s -L -O -J "$DATABASE_URL" -w "%{filename_effective}")
-
- # Verify the signature
- echo "Verifying signature..."
- if sha256sum -c --status "$SIGNATURE_FILE"; then
- echo "Signature is valid."
- else
- echo "Signature is invalid. Aborting."
- exit 1
- fi
-
- # Unpack the database file
- EXTRACTION_DIR=$(basename "$DATABASE_FILE" .tar.gz)
- echo "Unpacking $DATABASE_FILE..."
- mkdir -p "$EXTRACTION_DIR"
- tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1
-
- MMDB_FILE="GeoLite2-City.mmdb"
- cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE
-
- # Remove downloaded files
- rm -r "$EXTRACTION_DIR"
- rm "$DATABASE_FILE" "$SIGNATURE_FILE"
-
- # Done. Print next steps
- echo ""
- echo "Process completed successfully."
- echo "Now you can place $MMDB_FILE to 'datadir' of management service."
- echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/"
-}
-
-
-download_geolite_csv_and_create_sqlite_db() {
- DATABASE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip"
- SIGNATURE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256"
-
-
- # Download the database file
- echo "Downloading csv signature file..."
- SIGNATURE_FILE=$(curl -s -L -O -J "$SIGNATURE_URL" -w "%{filename_effective}")
- echo "Downloading csv database file..."
- DATABASE_FILE=$(curl -s -L -O -J "$DATABASE_URL" -w "%{filename_effective}")
-
- # Verify the signature
- echo "Verifying signature..."
- if sha256sum -c --status "$SIGNATURE_FILE"; then
- echo "Signature is valid."
- else
- echo "Signature is invalid. Aborting."
- exit 1
- fi
-
- # Unpack the database file
- EXTRACTION_DIR=$(basename "$DATABASE_FILE" .zip)
- DB_NAME="geonames.db"
-
- echo "Unpacking $DATABASE_FILE..."
- unzip "$DATABASE_FILE" > /dev/null 2>&1
-
-# Create SQLite database and import data from CSV
-sqlite3 "$DB_NAME" < dashboard.env
echo "" > turnserver.conf
echo "" > management.json
+ echo "" > relay.env
mkdir -p machinekey
chmod 777 machinekey
@@ -498,6 +514,7 @@ initEnvironment() {
renderTurnServerConf > turnserver.conf
renderManagementJson > management.json
renderDashboardEnv > dashboard.env
+ renderRelayEnv > relay.env
echo -e "\nStarting NetBird services\n"
$DOCKER_COMPOSE_COMMAND up -d
@@ -541,7 +558,7 @@ renderCaddyfile() {
# clickjacking protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-frame-options
- X-Frame-Options "DENY"
+ X-Frame-Options "SAMEORIGIN"
# xss protection
# https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-xss-protection
@@ -559,6 +576,8 @@ renderCaddyfile() {
:80${CADDY_SECURE_DOMAIN} {
import security_headers
+ # relay
+ reverse_proxy /relay* relay:80
# Signal
reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000
# Management
@@ -629,6 +648,11 @@ renderManagementJson() {
],
"TimeBasedCredentials": false
},
+ "Relay": {
+ "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"],
+ "CredentialsTTL": "24h",
+ "Secret": "$NETBIRD_RELAY_AUTH_SECRET"
+ },
"Signal": {
"Proto": "$NETBIRD_HTTP_PROTOCOL",
"URI": "$NETBIRD_DOMAIN:$NETBIRD_PORT"
@@ -744,6 +768,15 @@ POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD
EOF
}
+renderRelayEnv() {
+ cat < management.PeerSystemMeta
17, // 1: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
- 20, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig
- 22, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig
- 21, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap
- 37, // 5: management.SyncResponse.Checks:type_name -> management.Checks
+ 21, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig
+ 23, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig
+ 22, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap
+ 38, // 5: management.SyncResponse.Checks:type_name -> management.Checks
13, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta
13, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta
10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys
- 36, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress
+ 37, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress
11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment
12, // 11: management.PeerSystemMeta.files:type_name -> management.File
17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig
- 20, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
- 37, // 14: management.LoginResponse.Checks:type_name -> management.Checks
- 38, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
+ 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig
+ 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks
+ 42, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp
18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig
- 19, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
+ 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig
18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig
- 0, // 19: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
- 18, // 20: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig
- 23, // 21: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
- 20, // 22: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
- 22, // 23: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
- 29, // 24: management.NetworkMap.Routes:type_name -> management.Route
- 30, // 25: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
- 22, // 26: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
- 35, // 27: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
- 23, // 28: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
- 1, // 29: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
- 28, // 30: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 28, // 31: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
- 33, // 32: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
- 31, // 33: management.DNSConfig.CustomZones:type_name -> management.CustomZone
- 32, // 34: management.CustomZone.Records:type_name -> management.SimpleRecord
- 34, // 35: management.NameServerGroup.NameServers:type_name -> management.NameServer
- 2, // 36: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction
- 3, // 37: management.FirewallRule.Action:type_name -> management.FirewallRule.action
- 4, // 38: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol
- 5, // 39: management.ManagementService.Login:input_type -> management.EncryptedMessage
- 5, // 40: management.ManagementService.Sync:input_type -> management.EncryptedMessage
- 16, // 41: management.ManagementService.GetServerKey:input_type -> management.Empty
- 16, // 42: management.ManagementService.isHealthy:input_type -> management.Empty
- 5, // 43: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 44: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
- 5, // 45: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
- 5, // 46: management.ManagementService.Login:output_type -> management.EncryptedMessage
- 5, // 47: management.ManagementService.Sync:output_type -> management.EncryptedMessage
- 15, // 48: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
- 16, // 49: management.ManagementService.isHealthy:output_type -> management.Empty
- 5, // 50: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
- 5, // 51: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
- 16, // 52: management.ManagementService.SyncMeta:output_type -> management.Empty
- 46, // [46:53] is the sub-list for method output_type
- 39, // [39:46] is the sub-list for method input_type
- 39, // [39:39] is the sub-list for extension type_name
- 39, // [39:39] is the sub-list for extension extendee
- 0, // [0:39] is the sub-list for field type_name
+ 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig
+ 3, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol
+ 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig
+ 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig
+ 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig
+ 23, // 24: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig
+ 30, // 25: management.NetworkMap.Routes:type_name -> management.Route
+ 31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig
+ 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig
+ 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule
+ 40, // 29: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule
+ 24, // 30: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig
+ 4, // 31: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider
+ 29, // 32: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 29, // 33: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig
+ 34, // 34: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup
+ 32, // 35: management.DNSConfig.CustomZones:type_name -> management.CustomZone
+ 33, // 36: management.CustomZone.Records:type_name -> management.SimpleRecord
+ 35, // 37: management.NameServerGroup.NameServers:type_name -> management.NameServer
+ 1, // 38: management.FirewallRule.Direction:type_name -> management.RuleDirection
+ 2, // 39: management.FirewallRule.Action:type_name -> management.RuleAction
+ 0, // 40: management.FirewallRule.Protocol:type_name -> management.RuleProtocol
+ 41, // 41: management.PortInfo.range:type_name -> management.PortInfo.Range
+ 2, // 42: management.RouteFirewallRule.action:type_name -> management.RuleAction
+ 0, // 43: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol
+ 39, // 44: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo
+ 5, // 45: management.ManagementService.Login:input_type -> management.EncryptedMessage
+ 5, // 46: management.ManagementService.Sync:input_type -> management.EncryptedMessage
+ 16, // 47: management.ManagementService.GetServerKey:input_type -> management.Empty
+ 16, // 48: management.ManagementService.isHealthy:input_type -> management.Empty
+ 5, // 49: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 50: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage
+ 5, // 51: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage
+ 5, // 52: management.ManagementService.Login:output_type -> management.EncryptedMessage
+ 5, // 53: management.ManagementService.Sync:output_type -> management.EncryptedMessage
+ 15, // 54: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse
+ 16, // 55: management.ManagementService.isHealthy:output_type -> management.Empty
+ 5, // 56: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage
+ 5, // 57: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage
+ 16, // 58: management.ManagementService.SyncMeta:output_type -> management.Empty
+ 52, // [52:59] is the sub-list for method output_type
+ 45, // [45:52] is the sub-list for method input_type
+ 45, // [45:45] is the sub-list for extension type_name
+ 45, // [45:45] is the sub-list for extension extendee
+ 0, // [0:45] is the sub-list for field type_name
}
func init() { file_management_proto_init() }
@@ -3256,7 +3630,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ProtectedHostConfig); i {
+ switch v := v.(*RelayConfig); i {
case 0:
return &v.state
case 1:
@@ -3268,7 +3642,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PeerConfig); i {
+ switch v := v.(*ProtectedHostConfig); i {
case 0:
return &v.state
case 1:
@@ -3280,7 +3654,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NetworkMap); i {
+ switch v := v.(*PeerConfig); i {
case 0:
return &v.state
case 1:
@@ -3292,7 +3666,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*RemotePeerConfig); i {
+ switch v := v.(*NetworkMap); i {
case 0:
return &v.state
case 1:
@@ -3304,7 +3678,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*SSHConfig); i {
+ switch v := v.(*RemotePeerConfig); i {
case 0:
return &v.state
case 1:
@@ -3316,7 +3690,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DeviceAuthorizationFlowRequest); i {
+ switch v := v.(*SSHConfig); i {
case 0:
return &v.state
case 1:
@@ -3328,7 +3702,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DeviceAuthorizationFlow); i {
+ switch v := v.(*DeviceAuthorizationFlowRequest); i {
case 0:
return &v.state
case 1:
@@ -3340,7 +3714,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PKCEAuthorizationFlowRequest); i {
+ switch v := v.(*DeviceAuthorizationFlow); i {
case 0:
return &v.state
case 1:
@@ -3352,7 +3726,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*PKCEAuthorizationFlow); i {
+ switch v := v.(*PKCEAuthorizationFlowRequest); i {
case 0:
return &v.state
case 1:
@@ -3364,7 +3738,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*ProviderConfig); i {
+ switch v := v.(*PKCEAuthorizationFlow); i {
case 0:
return &v.state
case 1:
@@ -3376,7 +3750,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*Route); i {
+ switch v := v.(*ProviderConfig); i {
case 0:
return &v.state
case 1:
@@ -3388,7 +3762,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*DNSConfig); i {
+ switch v := v.(*Route); i {
case 0:
return &v.state
case 1:
@@ -3400,7 +3774,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*CustomZone); i {
+ switch v := v.(*DNSConfig); i {
case 0:
return &v.state
case 1:
@@ -3412,7 +3786,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*SimpleRecord); i {
+ switch v := v.(*CustomZone); i {
case 0:
return &v.state
case 1:
@@ -3424,7 +3798,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NameServerGroup); i {
+ switch v := v.(*SimpleRecord); i {
case 0:
return &v.state
case 1:
@@ -3436,7 +3810,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NameServer); i {
+ switch v := v.(*NameServerGroup); i {
case 0:
return &v.state
case 1:
@@ -3448,7 +3822,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*FirewallRule); i {
+ switch v := v.(*NameServer); i {
case 0:
return &v.state
case 1:
@@ -3460,7 +3834,7 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} {
- switch v := v.(*NetworkAddress); i {
+ switch v := v.(*FirewallRule); i {
case 0:
return &v.state
case 1:
@@ -3472,6 +3846,18 @@ func file_management_proto_init() {
}
}
file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*NetworkAddress); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Checks); i {
case 0:
return &v.state
@@ -3483,6 +3869,46 @@ func file_management_proto_init() {
return nil
}
}
+ file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*PortInfo); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*RouteFirewallRule); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} {
+ switch v := v.(*PortInfo_Range); i {
+ case 0:
+ return &v.state
+ case 1:
+ return &v.sizeCache
+ case 2:
+ return &v.unknownFields
+ default:
+ return nil
+ }
+ }
+ }
+ file_management_proto_msgTypes[34].OneofWrappers = []interface{}{
+ (*PortInfo_Port)(nil),
+ (*PortInfo_Range_)(nil),
}
type x struct{}
out := protoimpl.TypeBuilder{
@@ -3490,7 +3916,7 @@ func file_management_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_management_proto_rawDesc,
NumEnums: 5,
- NumMessages: 33,
+ NumMessages: 37,
NumExtensions: 0,
NumServices: 1,
},
diff --git a/management/proto/management.proto b/management/proto/management.proto
index 06b243773..fe6a828b1 100644
--- a/management/proto/management.proto
+++ b/management/proto/management.proto
@@ -177,6 +177,8 @@ message WiretrusteeConfig {
// a Signal server config
HostConfig signal = 3;
+
+ RelayConfig relay = 4;
}
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@@ -193,6 +195,13 @@ message HostConfig {
DTLS = 4;
}
}
+
+message RelayConfig {
+ repeated string urls = 1;
+ string tokenPayload = 2;
+ string tokenSignature = 3;
+}
+
// ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers
message ProtectedHostConfig {
@@ -245,6 +254,12 @@ message NetworkMap {
// firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality.
bool firewallRulesIsEmpty = 9;
+
+ // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer
+ repeated RouteFirewallRule routesFirewallRules = 10;
+
+ // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality.
+ bool routesFirewallRulesIsEmpty = 11;
}
// RemotePeerConfig represents a configuration of a remote peer.
@@ -375,29 +390,32 @@ message NameServer {
int64 Port = 3;
}
+enum RuleProtocol {
+ UNKNOWN = 0;
+ ALL = 1;
+ TCP = 2;
+ UDP = 3;
+ ICMP = 4;
+}
+
+enum RuleDirection {
+ IN = 0;
+ OUT = 1;
+}
+
+enum RuleAction {
+ ACCEPT = 0;
+ DROP = 1;
+}
+
+
// FirewallRule represents a firewall rule
message FirewallRule {
string PeerIP = 1;
- direction Direction = 2;
- action Action = 3;
- protocol Protocol = 4;
+ RuleDirection Direction = 2;
+ RuleAction Action = 3;
+ RuleProtocol Protocol = 4;
string Port = 5;
-
- enum direction {
- IN = 0;
- OUT = 1;
- }
- enum action {
- ACCEPT = 0;
- DROP = 1;
- }
- enum protocol {
- UNKNOWN = 0;
- ALL = 1;
- TCP = 2;
- UDP = 3;
- ICMP = 4;
- }
}
message NetworkAddress {
@@ -406,5 +424,40 @@ message NetworkAddress {
}
message Checks {
- repeated string Files= 1;
+ repeated string Files = 1;
}
+
+
+message PortInfo {
+ oneof portSelection {
+ uint32 port = 1;
+ Range range = 2;
+ }
+
+ message Range {
+ uint32 start = 1;
+ uint32 end = 2;
+ }
+}
+
+// RouteFirewallRule signifies a firewall rule applicable for a routed network.
+message RouteFirewallRule {
+ // sourceRanges IP ranges of the routing peers.
+ repeated string sourceRanges = 1;
+
+ // Action to be taken by the firewall when the rule is applicable.
+ RuleAction action = 2;
+
+ // Network prefix for the routed network.
+ string destination = 3;
+
+ // Protocol of the routed network.
+ RuleProtocol protocol = 4;
+
+ // Details about the port.
+ PortInfo portInfo = 5;
+
+ // IsDynamic indicates if the route is a DNS route.
+ bool isDynamic = 6;
+}
+
diff --git a/management/server/account.go b/management/server/account.go
index 0a91ae6d8..d5e8c8cf8 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -20,11 +20,6 @@ import (
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
- gocache "github.com/patrickmn/go-cache"
- "github.com/rs/xid"
- log "github.com/sirupsen/logrus"
- "golang.org/x/exp/maps"
-
"github.com/netbirdio/netbird/base62"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
@@ -41,6 +36,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
+ gocache "github.com/patrickmn/go-cache"
+ "github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/exp/maps"
)
const (
@@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration {
type AccountManager interface {
GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error)
+ GetAccount(ctx context.Context, accountID string) (*Account, error)
CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration,
autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error)
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
@@ -75,12 +75,14 @@ type AccountManager interface {
SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error)
SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error)
GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error)
- GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error)
- GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error)
+ GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error)
+ GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error)
+ GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
+ GetUserByID(ctx context.Context, id string) (*User, error)
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -94,6 +96,7 @@ type AccountManager interface {
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error)
GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error)
+ UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error
GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error)
@@ -106,11 +109,11 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
- SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error
+ SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
- CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
+ CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error
DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error
ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error)
@@ -144,6 +147,7 @@ type AccountManager interface {
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
+ GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error)
}
type DefaultAccountManager struct {
@@ -160,6 +164,8 @@ type DefaultAccountManager struct {
eventStore activity.Store
geo *geolocation.Geolocation
+ requestBuffer *AccountRequestBuffer
+
// singleAccountMode indicates whether the instance has a single account.
// If true, then every new user will end up under the same account.
// This value will be set to false if management service has more than one account.
@@ -260,6 +266,16 @@ type AccountSettings struct {
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
}
+// AccountNetwork is used in gorm to only load the network and not the whole account
+type AccountNetwork struct {
+ Network *Network `gorm:"embedded;embeddedPrefix:network_"`
+}
+
+// AccountDNSSettings used in gorm to only load dns settings and not whole account
+type AccountDNSSettings struct {
+ DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
+}
+
type UserPermissions struct {
DashboardView string `json:"dashboard_view"`
}
@@ -444,6 +460,7 @@ func (a *Account) GetPeerNetworkMap(
}
routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect)
+ routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap)
dnsManagementStatus := a.getPeerDNSManagementStatus(peerID)
dnsUpdate := nbdns.Config{
@@ -467,12 +484,19 @@ func (a *Account) GetPeerNetworkMap(
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
+ RoutesFirewallRules: routesFirewallRules,
}
if metrics != nil {
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
+
+ if objectCount > 5000 {
+ log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+
+ "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d",
+ a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules))
+ }
}
return nm
@@ -691,14 +715,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string {
return grps
}
-func (a *Account) getUserGroups(userID string) ([]string, error) {
- user, err := a.FindUser(userID)
- if err != nil {
- return nil, err
- }
- return user.AutoGroups, nil
-}
-
func (a *Account) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := a.getPeerGroups(peerID)
enabled := true
@@ -725,14 +741,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap {
return groupList
}
-func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) {
- key, err := a.FindSetupKey(setupKey)
- if err != nil {
- return nil, err
- }
- return key.AutoGroups, nil
-}
-
func (a *Account) getTakenIPs() []net.IP {
var takenIps []net.IP
for _, existingPeer := range a.Peers {
@@ -966,6 +974,7 @@ func BuildManager(
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
+ requestBuffer: NewAccountRequestBuffer(ctx, store),
}
allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
@@ -1253,25 +1262,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return nil
}
-// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and
-// userID doesn't have an account associated with it, one account is created
-// domain is used to create a new account if no account is found
-func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) {
+// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided.
+// If an accountID is provided, it checks if the account exists and returns it.
+// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID.
+// If the user doesn't have an account, it creates one using the provided domain.
+// Returns the account ID or an error if none is found or created.
+func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) {
if accountID != "" {
- return am.Store.GetAccount(ctx, accountID)
- } else if userID != "" {
- account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
+ exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
- return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID)
+ return "", err
}
- err = am.addAccountIDToIDPAppMeta(ctx, userID, account)
- if err != nil {
- return nil, err
+ if !exists {
+ return "", status.Errorf(status.NotFound, "account %s does not exist", accountID)
}
- return account, nil
+ return accountID, nil
}
- return nil, status.Errorf(status.NotFound, "no valid user or account Id provided")
+ if userID != "" {
+ account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
+ if err != nil {
+ return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
+ }
+
+ if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
+ return "", err
+ }
+
+ return account.Id, nil
+ }
+
+ return "", status.Errorf(status.NotFound, "no valid userID or accountID provided")
}
func isNil(i idp.Manager) bool {
@@ -1614,13 +1635,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
}
// redeemInvite checks whether user has been invited and redeems the invite
-func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error {
+func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error {
// only possible with the enabled IdP manager
if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil
}
+ account, err := am.Store.GetAccount(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil {
return err
@@ -1679,6 +1705,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string
return am.Store.SaveAccount(ctx, account)
}
+// GetAccount returns an account associated with this account ID.
+func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
+ return am.Store.GetAccount(ctx, accountID)
+}
+
// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
if len(token) != PATLength {
@@ -1727,10 +1758,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st
return account, user, pat, nil
}
-// GetAccountFromToken returns an account associated with this token
-func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) {
+// GetAccountByID returns an account associated with this account ID.
+func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
+ return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
+ }
+
+ return am.Store.GetAccount(ctx, accountID)
+}
+
+// GetAccountIDFromToken returns an account ID associated with this token.
+func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
if claims.UserId == "" {
- return nil, nil, fmt.Errorf("user ID is empty")
+ return "", "", fmt.Errorf("user ID is empty")
}
if am.singleAccountMode && am.singleAccountModeDomain != "" {
// This section is mostly related to self-hosted installations.
@@ -1740,115 +1785,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
}
- newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims)
+ accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims)
if err != nil {
- return nil, nil, err
- }
- unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
- alreadyUnlocked := false
- defer func() {
- if !alreadyUnlocked {
- unlock()
- }
- }()
-
- account, err := am.Store.GetAccount(ctx, newAcc.Id)
- if err != nil {
- return nil, nil, err
+ return "", "", err
}
- user := account.Users[claims.UserId]
- if user == nil {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId)
+ if err != nil {
// this is not really possible because we got an account by user ID
- return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId)
+ return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
}
if !user.IsServiceUser && claims.Invited {
- err = am.redeemInvite(ctx, account, claims.UserId)
+ err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil {
- return nil, nil, err
+ return "", "", err
}
}
- if account.Settings.JWTGroupsEnabled {
- if account.Settings.JWTGroupsClaimName == "" {
- log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
- return account, user, nil
- }
- if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
- if slice, ok := claim.([]interface{}); ok {
- var groupsNames []string
- for _, item := range slice {
- if g, ok := item.(string); ok {
- groupsNames = append(groupsNames, g)
- } else {
- log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item)
- }
- }
-
- oldGroups := make([]string, len(user.AutoGroups))
- copy(oldGroups, user.AutoGroups)
- // if groups were added or modified, save the account
- if account.SetJWTGroups(claims.UserId, groupsNames) {
- if account.Settings.GroupsPropagationEnabled {
- if user, err := account.FindUser(claims.UserId); err == nil {
- addNewGroups := difference(user.AutoGroups, oldGroups)
- removeOldGroups := difference(oldGroups, user.AutoGroups)
- account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
- account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
-
- account.Network.IncSerial()
- if err := am.Store.SaveAccount(ctx, account); err != nil {
- log.WithContext(ctx).Errorf("failed to save account: %v", err)
- } else {
- log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
-
- if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
- am.updateAccountPeers(ctx, account)
- }
- unlock()
-
- alreadyUnlocked = true
- for _, g := range addNewGroups {
- if group := account.GetGroup(g); group != nil {
- am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
- map[string]any{
- "group": group.Name,
- "group_id": group.ID,
- "is_service_user": user.IsServiceUser,
- "user_name": user.ServiceUserName})
- }
- }
- for _, g := range removeOldGroups {
- if group := account.GetGroup(g); group != nil {
- am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
- map[string]any{
- "group": group.Name,
- "group_id": group.ID,
- "is_service_user": user.IsServiceUser,
- "user_name": user.ServiceUserName})
- }
- }
- }
- }
- } else {
- if err := am.Store.SaveAccount(ctx, account); err != nil {
- log.WithContext(ctx).Errorf("failed to save account: %v", err)
- }
- }
- }
- } else {
- log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName)
- }
- } else {
- log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName)
- }
+ if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil {
+ return "", "", err
}
- return account, user, nil
+ return accountID, user.Id, nil
}
-// getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
+// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
+// and propagates changes to peers if group propagation is enabled.
+func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error {
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
+
+ if settings == nil || !settings.JWTGroupsEnabled {
+ return nil
+ }
+
+ if settings.JWTGroupsClaimName == "" {
+ log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
+ return nil
+ }
+
+ // TODO: Remove GetAccount after refactoring account peer's update
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
+ account, err := am.Store.GetAccount(ctx, accountID)
+ if err != nil {
+ return err
+ }
+
+ jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
+
+ oldGroups := make([]string, len(user.AutoGroups))
+ copy(oldGroups, user.AutoGroups)
+
+ // Update the account if group membership changes
+ if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
+ addNewGroups := difference(user.AutoGroups, oldGroups)
+ removeOldGroups := difference(oldGroups, user.AutoGroups)
+
+ if settings.GroupsPropagationEnabled {
+ account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
+ account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
+ account.Network.IncSerial()
+ }
+
+ if err := am.Store.SaveAccount(ctx, account); err != nil {
+ log.WithContext(ctx).Errorf("failed to save account: %v", err)
+ return nil
+ }
+
+ // Propagate changes to peers if group propagation is enabled
+ if settings.GroupsPropagationEnabled {
+ log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
+ am.updateAccountPeers(ctx, account)
+ }
+
+ for _, g := range addNewGroups {
+ if group := account.GetGroup(g); group != nil {
+ am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
+ map[string]any{
+ "group": group.Name,
+ "group_id": group.ID,
+ "is_service_user": user.IsServiceUser,
+ "user_name": user.ServiceUserName})
+ }
+ }
+
+ for _, g := range removeOldGroups {
+ if group := account.GetGroup(g); group != nil {
+ am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
+ map[string]any{
+ "group": group.Name,
+ "group_id": group.ID,
+ "is_service_user": user.IsServiceUser,
+ "user_name": user.ServiceUserName})
+ }
+ }
+ }
+
+ return nil
+}
+
+// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims.
// if domain is of the PrivateCategory category, it will evaluate
// if account is new, existing or if there is another account with the same domain
//
@@ -1865,26 +1906,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes
//
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
-func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) {
+func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" {
- return nil, fmt.Errorf("user ID is empty")
+ return "", fmt.Errorf("user ID is empty")
}
+
// if Account ID is part of the claims
// it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
- return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
+ return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
} else if claims.AccountId != "" {
- accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
+ userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err != nil {
- return nil, err
+ return "", err
}
- if _, ok := accountFromID.Users[claims.UserId]; !ok {
- return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
+
+ if userAccountID != claims.AccountId {
+ return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
}
- if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
- return accountFromID, nil
+
+ domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId)
+ if err != nil {
+ return "", err
+ }
+
+ if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain {
+ return userAccountID, nil
}
}
@@ -1894,48 +1943,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId)
// We checked if the domain has a primary account already
- domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
+ domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain)
if err != nil {
// if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound {
- return nil, err
+ return "", err
}
}
- account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
+ userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
if err == nil {
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id)
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID)
defer unlockAccount()
- account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
+ account, err := am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil {
- return nil, err
+ return "", err
}
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as
// non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain
// was previously unclassified or classified as public so N users that logged int that time, has they own account
// and peers that shouldn't be lost.
- primaryDomain := domainAccount == nil || account.Id == domainAccount.Id
-
- err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
- if err != nil {
- return nil, err
+ primaryDomain := domainAccountID == "" || account.Id == domainAccountID
+ if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil {
+ return "", err
}
- return account, nil
+
+ return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
- if domainAccount != nil {
- unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
+ var domainAccount *Account
+ if domainAccountID != "" {
+ unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil {
- return nil, err
+ return "", err
}
}
- return am.handleNewUserAccount(ctx, domainAccount, claims)
+
+ account, err := am.handleNewUserAccount(ctx, domainAccount, claims)
+ if err != nil {
+ return "", err
+ }
+ return account.Id, nil
} else {
// other error
- return nil, err
+ return "", err
}
}
@@ -2028,26 +2082,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string {
// CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT
// group propagation and set the list of groups with access permissions.
func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
- account, _, err := am.GetAccountFromToken(ctx, claims)
+ accountID, _, err := am.GetAccountIDFromToken(ctx, claims)
+ if err != nil {
+ return err
+ }
+
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
// Ensures JWT group synchronization to the management is enabled before,
// filtering access based on the allowed groups.
- if account.Settings != nil && account.Settings.JWTGroupsEnabled {
- if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 {
- userJWTGroups := make([]string, 0)
-
- if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok {
- if claimGroups, ok := claim.([]interface{}); ok {
- for _, g := range claimGroups {
- if group, ok := g.(string); ok {
- userJWTGroups = append(userJWTGroups, group)
- }
- }
- }
- }
+ if settings != nil && settings.JWTGroupsEnabled {
+ if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 {
+ userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
if !userHasAllowedGroup(allowedGroups, userJWTGroups) {
return fmt.Errorf("user does not belong to any of the allowed JWT groups")
@@ -2076,6 +2125,60 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
}
+func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
+ if err != nil {
+ return false, err
+ }
+
+ err = checkIfPeerOwnerIsBlocked(peer, user)
+ if err != nil {
+ return false, err
+ }
+
+ if peerLoginExpired(ctx, peer, settings) {
+ err = am.handleExpiredPeer(ctx, user, peer)
+ if err != nil {
+ return false, err
+ }
+ return true, nil
+ }
+
+ return false, nil
+}
+
+func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) {
+ existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return "", fmt.Errorf("failed to get peer dns labels: %w", err)
+ }
+
+ labelMap := ConvertSliceToMap(existingLabels)
+ newLabel, err := getPeerHostLabel(peerHostName, labelMap)
+ if err != nil {
+ return "", fmt.Errorf("failed to get new host label: %w", err)
+ }
+
+ if newLabel == "" {
+		return "", fmt.Errorf("got an empty host label for peer host %s", peerHostName)
+ }
+
+ return newLabel, nil
+}
+
+func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
+ }
+
+ if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) {
+ return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
+ }
+
+ return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+}
+
// addAllGroup to account object if it doesn't exist
func addAllGroup(account *Account) error {
if len(account.Groups) == 0 {
@@ -2158,6 +2261,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac
return acc
}
+// extractJWTGroups extracts the group names from a JWT token's claims.
+func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string {
+ userJWTGroups := make([]string, 0)
+
+ if claim, ok := claims.Raw[claimName]; ok {
+ if claimGroups, ok := claim.([]interface{}); ok {
+ for _, g := range claimGroups {
+ if group, ok := g.(string); ok {
+ userJWTGroups = append(userJWTGroups, group)
+ } else {
+ log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g)
+ }
+ }
+ }
+ } else {
+		log.WithContext(ctx).Debugf("JWT claim %q not found in token claims", claimName)
+ }
+
+ return userJWTGroups
+}
+
// userHasAllowedGroup checks if a user belongs to any of the allowed groups.
func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool {
for _, userGroup := range userGroups {
diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go
new file mode 100644
index 000000000..5f4897e6a
--- /dev/null
+++ b/management/server/account_request_buffer.go
@@ -0,0 +1,108 @@
+package server
+
+import (
+ "context"
+ "os"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// AccountRequest holds the result channel to return the requested account.
+type AccountRequest struct {
+ AccountID string
+ ResultChan chan *AccountResult
+}
+
+// AccountResult holds the account data or an error.
+type AccountResult struct {
+ Account *Account
+ Err error
+}
+
+type AccountRequestBuffer struct {
+ store Store
+ getAccountRequests map[string][]*AccountRequest
+ mu sync.Mutex
+ getAccountRequestCh chan *AccountRequest
+ bufferInterval time.Duration
+}
+
+func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer {
+ bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL")
+ bufferInterval, err := time.ParseDuration(bufferIntervalStr)
+ if err != nil {
+ if bufferIntervalStr != "" {
+ log.WithContext(ctx).Warnf("failed to parse account request buffer interval: %s", err)
+ }
+ bufferInterval = 100 * time.Millisecond
+ }
+
+ log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval)
+
+ ac := AccountRequestBuffer{
+ store: store,
+ getAccountRequests: make(map[string][]*AccountRequest),
+ getAccountRequestCh: make(chan *AccountRequest),
+ bufferInterval: bufferInterval,
+ }
+
+ go ac.processGetAccountRequests(ctx)
+
+ return &ac
+}
+func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) {
+ req := &AccountRequest{
+ AccountID: accountID,
+ ResultChan: make(chan *AccountResult, 1),
+ }
+
+ log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
+ startTime := time.Now()
+ ac.getAccountRequestCh <- req
+
+ result := <-req.ResultChan
+ log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
+ return result.Account, result.Err
+}
+
+func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
+ ac.mu.Lock()
+ requests := ac.getAccountRequests[accountID]
+ delete(ac.getAccountRequests, accountID)
+ ac.mu.Unlock()
+
+ if len(requests) == 0 {
+ return
+ }
+
+ startTime := time.Now()
+ account, err := ac.store.GetAccount(ctx, accountID)
+ log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime))
+ result := &AccountResult{Account: account, Err: err}
+
+ for _, req := range requests {
+ req.ResultChan <- result
+ close(req.ResultChan)
+ }
+}
+
+func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) {
+ for {
+ select {
+ case req := <-ac.getAccountRequestCh:
+ ac.mu.Lock()
+ ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req)
+ if len(ac.getAccountRequests[req.AccountID]) == 1 {
+ go func(ctx context.Context, accountID string) {
+ time.Sleep(ac.bufferInterval)
+ ac.processGetAccountBatch(ctx, accountID)
+ }(ctx, req.AccountID)
+ }
+ ac.mu.Unlock()
+ case <-ctx.Done():
+ return
+ }
+ }
+}
diff --git a/management/server/account_test.go b/management/server/account_test.go
index d89ce4e4a..1bad43c13 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) {
assert.Equal(t, account.Id, ev.TargetID)
}
-func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
+func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) {
type initUserParams jwtclaims.AuthorizationClaims
type test struct {
@@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain)
require.NoError(t, err, "create init user failed")
+ initAccount, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get init account failed")
+
if testCase.inputUpdateAttrs {
err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true)
require.NoError(t, err, "update init user failed")
@@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id
}
- account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
+ accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "unable to create account manager")
accountID := initAccount.Id
- acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain)
+ accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain)
require.NoError(t, err, "create init user failed")
// as initAccount was created without account id we have to take the id after account initialization
- // that happens inside the GetAccountByUserOrAccountID where the id is getting generated
+ // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated
// it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it
- initAccount = acc
+ initAccount, err = manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get init account failed")
claims := jwtclaims.AuthorizationClaims{
AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount
@@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
}
t.Run("JWT groups disabled", func(t *testing.T) {
- account, _, err := manager.GetAccountFromToken(context.Background(), claims)
+ accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
require.Len(t, account.Groups, 1, "only ALL group should exists")
})
@@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
- account, _, err := manager.GetAccountFromToken(context.Background(), claims)
+ accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
})
@@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
- account, _, err := manager.GetAccountFromToken(context.Background(), claims)
+ accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "get account failed")
+
require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{}
@@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) {
userId := "test_user"
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "")
if err != nil {
t.Fatal(err)
}
- if account == nil {
+ if accountID == "" {
t.Fatalf("expected to create an account for a user %s", userId)
return
}
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
if err != nil {
- t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id)
+ t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID)
}
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "")
if err == nil {
t.Errorf("expected an error when user and account IDs are empty")
}
@@ -1225,10 +1245,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
}()
- if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
- t.Errorf("save policy: %v", err)
- return
- }
+ if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
+		t.Errorf("save policy: %v", err)
+ return
+ }
wg.Wait()
}
@@ -1636,9 +1656,10 @@ func TestAccount_Copy(t *testing.T) {
},
Routes: map[route.ID]*route.Route{
"route1": {
- ID: "route1",
- PeerGroups: []string{},
- Groups: []string{"group1"},
+ ID: "route1",
+ PeerGroups: []string{},
+ Groups: []string{"group1"},
+ AccessControlGroups: []string{},
},
},
NameServerGroups: map[string]*nbdns.NameServerGroup{
@@ -1705,19 +1726,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
- assert.NotNil(t, account.Settings)
- assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true)
- assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour)
+ settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
+ require.NoError(t, err, "unable to get account settings")
+
+ assert.NotNil(t, settings)
+ assert.Equal(t, settings.PeerLoginExpirationEnabled, true)
+ assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour)
}
func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1729,11 +1753,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
})
require.NoError(t, err, "unable to add peer")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "unable to get the account")
+
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
- account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+
+ account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1770,7 +1799,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1781,7 +1810,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
LoginExpirationEnabled: true,
})
require.NoError(t, err, "unable to add peer")
- _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1798,8 +1827,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
},
}
- account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "unable to get the account")
+
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@@ -1814,7 +1847,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
@@ -1826,8 +1859,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
})
require.NoError(t, err, "unable to add peer")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to get the account")
+
+ account, err := manager.Store.GetAccount(context.Background(), accountID)
+ require.NoError(t, err, "unable to get the account")
+
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
@@ -1870,10 +1907,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
- account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "")
+ accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
- updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
@@ -1881,19 +1918,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
- account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
+ accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")
- assert.False(t, account.Settings.PeerLoginExpirationEnabled)
- assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour)
+ settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID)
+ require.NoError(t, err, "unable to get account settings")
- _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ assert.False(t, settings.PeerLoginExpirationEnabled)
+ assert.Equal(t, settings.PeerLoginExpiration, time.Hour)
+
+ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false,
})
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
- _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
+ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false,
})
diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go
index cf4dda746..096f49ea3 100644
--- a/management/server/activity/sqlite/crypt.go
+++ b/management/server/activity/sqlite/crypt.go
@@ -6,13 +6,14 @@ import (
"crypto/cipher"
"crypto/rand"
"encoding/base64"
- "fmt"
+ "errors"
)
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
type FieldEncrypt struct {
block cipher.Block
+ gcm cipher.AEAD
}
func GenerateKey() (string, error) {
@@ -35,14 +36,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) {
if err != nil {
return nil, err
}
+
+ gcm, err := cipher.NewGCM(block)
+ if err != nil {
+ return nil, err
+ }
+
ec := &FieldEncrypt{
block: block,
+ gcm: gcm,
}
return ec, nil
}
-func (ec *FieldEncrypt) Encrypt(payload string) string {
+func (ec *FieldEncrypt) LegacyEncrypt(payload string) string {
plainText := pkcs5Padding([]byte(payload))
cipherText := make([]byte, len(plainText))
cbc := cipher.NewCBCEncrypter(ec.block, iv)
@@ -50,7 +58,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string {
return base64.StdEncoding.EncodeToString(cipherText)
}
-func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
+// Encrypt encrypts plaintext using AES-GCM
+func (ec *FieldEncrypt) Encrypt(payload string) (string, error) {
+ plaintext := []byte(payload)
+ nonceSize := ec.gcm.NonceSize()
+
+ nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead())
+ if _, err := rand.Read(nonce); err != nil {
+ return "", err
+ }
+
+ ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil)
+
+ return base64.StdEncoding.EncodeToString(ciphertext), nil
+}
+
+func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) {
cipherText, err := base64.StdEncoding.DecodeString(data)
if err != nil {
return "", err
@@ -65,17 +88,49 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
return string(payload), nil
}
+// Decrypt decrypts ciphertext using AES-GCM
+func (ec *FieldEncrypt) Decrypt(data string) (string, error) {
+ cipherText, err := base64.StdEncoding.DecodeString(data)
+ if err != nil {
+ return "", err
+ }
+
+ nonceSize := ec.gcm.NonceSize()
+ if len(cipherText) < nonceSize {
+ return "", errors.New("cipher text too short")
+ }
+
+ nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:]
+ plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil)
+ if err != nil {
+ return "", err
+ }
+
+ return string(plainText), nil
+}
+
func pkcs5Padding(ciphertext []byte) []byte {
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padText...)
}
-
func pkcs5UnPadding(src []byte) ([]byte, error) {
srcLen := len(src)
- paddingLen := int(src[srcLen-1])
- if paddingLen >= srcLen || paddingLen > aes.BlockSize {
- return nil, fmt.Errorf("padding size error")
+ if srcLen == 0 {
+ return nil, errors.New("input data is empty")
}
+
+ paddingLen := int(src[srcLen-1])
+ if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen {
+ return nil, errors.New("invalid padding size")
+ }
+
+ // Verify that all padding bytes are the same
+ for i := 0; i < paddingLen; i++ {
+ if src[srcLen-1-i] != byte(paddingLen) {
+ return nil, errors.New("invalid padding")
+ }
+ }
+
return src[:srcLen-paddingLen], nil
}
diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go
index efa740921..aff3a08b1 100644
--- a/management/server/activity/sqlite/crypt_test.go
+++ b/management/server/activity/sqlite/crypt_test.go
@@ -1,6 +1,7 @@
package sqlite
import (
+ "bytes"
"testing"
)
@@ -15,7 +16,11 @@ func TestGenerateKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
- encrypted := ee.Encrypt(testData)
+ encrypted, err := ee.Encrypt(testData)
+ if err != nil {
+ t.Fatalf("failed to encrypt data: %s", err)
+ }
+
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -30,6 +35,32 @@ func TestGenerateKey(t *testing.T) {
}
}
+func TestGenerateKeyLegacy(t *testing.T) {
+ testData := "exampl@netbird.io"
+ key, err := GenerateKey()
+ if err != nil {
+ t.Fatalf("failed to generate key: %s", err)
+ }
+ ee, err := NewFieldEncrypt(key)
+ if err != nil {
+ t.Fatalf("failed to init email encryption: %s", err)
+ }
+
+ encrypted := ee.LegacyEncrypt(testData)
+ if encrypted == "" {
+ t.Fatalf("invalid encrypted text")
+ }
+
+ decrypted, err := ee.LegacyDecrypt(encrypted)
+ if err != nil {
+ t.Fatalf("failed to decrypt data: %s", err)
+ }
+
+ if decrypted != testData {
+ t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
+ }
+}
+
func TestCorruptKey(t *testing.T) {
testData := "exampl@netbird.io"
key, err := GenerateKey()
@@ -41,7 +72,11 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("failed to init email encryption: %s", err)
}
- encrypted := ee.Encrypt(testData)
+ encrypted, err := ee.Encrypt(testData)
+ if err != nil {
+ t.Fatalf("failed to encrypt data: %s", err)
+ }
+
if encrypted == "" {
t.Fatalf("invalid encrypted text")
}
@@ -61,3 +96,215 @@ func TestCorruptKey(t *testing.T) {
t.Fatalf("incorrect decryption, the result is: %s", res)
}
}
+
+func TestEncryptDecrypt(t *testing.T) {
+ // Generate a key for encryption/decryption
+ key, err := GenerateKey()
+ if err != nil {
+ t.Fatalf("Failed to generate key: %v", err)
+ }
+
+ // Initialize the FieldEncrypt with the generated key
+ ec, err := NewFieldEncrypt(key)
+ if err != nil {
+ t.Fatalf("Failed to create FieldEncrypt: %v", err)
+ }
+
+ // Test cases
+ testCases := []struct {
+ name string
+ input string
+ }{
+ {
+ name: "Empty String",
+ input: "",
+ },
+ {
+ name: "Short String",
+ input: "Hello",
+ },
+ {
+ name: "String with Spaces",
+ input: "Hello, World!",
+ },
+ {
+ name: "Long String",
+ input: "The quick brown fox jumps over the lazy dog.",
+ },
+ {
+ name: "Unicode Characters",
+ input: "こんにちは世界",
+ },
+ {
+ name: "Special Characters",
+ input: "!@#$%^&*()_+-=[]{}|;':\",./<>?",
+ },
+ {
+ name: "Numeric String",
+ input: "1234567890",
+ },
+ {
+ name: "Repeated Characters",
+ input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
+ },
+ {
+ name: "Multi-block String",
+ input: "This is a longer string that will span multiple blocks in the encryption algorithm.",
+ },
+ {
+ name: "Non-ASCII and ASCII Mix",
+ input: "Hello 世界 123",
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name+" - Legacy", func(t *testing.T) {
+ // Legacy Encryption
+ encryptedLegacy := ec.LegacyEncrypt(tc.input)
+ if encryptedLegacy == "" {
+ t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input)
+ }
+
+ // Legacy Decryption
+ decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy)
+ if err != nil {
+ t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err)
+ }
+
+ // Verify that the decrypted value matches the original input
+ if decryptedLegacy != tc.input {
+ t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input)
+ }
+ })
+
+ t.Run(tc.name+" - New", func(t *testing.T) {
+ // New Encryption
+ encryptedNew, err := ec.Encrypt(tc.input)
+ if err != nil {
+ t.Errorf("Encrypt failed for input '%s': %v", tc.input, err)
+ }
+ if encryptedNew == "" {
+ t.Errorf("Encrypt returned empty string for input '%s'", tc.input)
+ }
+
+ // New Decryption
+ decryptedNew, err := ec.Decrypt(encryptedNew)
+ if err != nil {
+ t.Errorf("Decrypt failed for input '%s': %v", tc.input, err)
+ }
+
+ // Verify that the decrypted value matches the original input
+ if decryptedNew != tc.input {
+ t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input)
+ }
+ })
+ }
+}
+
+func TestPKCS5UnPadding(t *testing.T) {
+ tests := []struct {
+ name string
+ input []byte
+ expected []byte
+ expectError bool
+ }{
+ {
+ name: "Valid Padding",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...),
+ expected: []byte("Hello, World!"),
+ },
+ {
+ name: "Empty Input",
+ input: []byte{},
+ expectError: true,
+ },
+ {
+ name: "Padding Length Zero",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...),
+ expectError: true,
+ },
+ {
+ name: "Padding Length Exceeds Block Size",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...),
+ expectError: true,
+ },
+ {
+ name: "Padding Length Exceeds Input Length",
+ input: []byte{5, 5, 5},
+ expectError: true,
+ },
+ {
+ name: "Invalid Padding Bytes",
+ input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...),
+ expectError: true,
+ },
+ {
+ name: "Valid Single Byte Padding",
+ input: append([]byte("Hello, World!"), byte(1)),
+ expected: []byte("Hello, World!"),
+ },
+ {
+ name: "Invalid Mixed Padding Bytes",
+ input: append([]byte("Hello, World!"), []byte{3, 3, 2}...),
+ expectError: true,
+ },
+ {
+ name: "Valid Full Block Padding",
+ input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...),
+ expected: []byte("Hello, World!"),
+ },
+ {
+ name: "Non-Padding Byte at End",
+ input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...),
+ expectError: true,
+ },
+ {
+ name: "Valid Padding with Different Text Length",
+ input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...),
+ expected: []byte("Test"),
+ },
+ {
+ name: "Padding Length Equal to Input Length",
+ input: bytes.Repeat([]byte{8}, 8),
+ expected: []byte{},
+ },
+ {
+ name: "Invalid Padding Length Zero (Again)",
+ input: append([]byte("Test"), byte(0)),
+ expectError: true,
+ },
+ {
+ name: "Padding Length Greater Than Input",
+ input: []byte{10},
+ expectError: true,
+ },
+ {
+ name: "Input Length Not Multiple of Block Size",
+ input: append([]byte("Invalid Length"), byte(1)),
+ expected: []byte("Invalid Length"),
+ },
+ {
+ name: "Valid Padding with Non-ASCII Characters",
+ input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...),
+ expected: []byte("こんにちは"),
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := pkcs5UnPadding(tt.input)
+ if tt.expectError {
+ if err == nil {
+ t.Errorf("Expected error but got nil")
+ }
+ } else {
+ if err != nil {
+ t.Errorf("Did not expect error but got: %v", err)
+ }
+ if !bytes.Equal(result, tt.expected) {
+ t.Errorf("Expected output %v, got %v", tt.expected, result)
+ }
+ }
+ })
+ }
+}
diff --git a/management/server/activity/sqlite/migration.go b/management/server/activity/sqlite/migration.go
new file mode 100644
index 000000000..28c5b3020
--- /dev/null
+++ b/management/server/activity/sqlite/migration.go
@@ -0,0 +1,157 @@
+package sqlite
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+
+ log "github.com/sirupsen/logrus"
+)
+
+func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
+ if _, err := db.Exec(createTableQuery); err != nil {
+ return err
+ }
+
+ if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil {
+ return err
+ }
+
+ if err := updateDeletedUsersTable(ctx, db); err != nil {
+ return fmt.Errorf("failed to update deleted_users table: %v", err)
+ }
+
+ return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db)
+}
+
+// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist.
+func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
+ exists, err := checkColumnExists(db, "deleted_users", "name")
+ if err != nil {
+ return err
+ }
+
+ if !exists {
+ log.WithContext(ctx).Debug("Adding name column to the deleted_users table")
+
+ _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
+ if err != nil {
+ return err
+ }
+
+ log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table")
+ }
+
+ exists, err = checkColumnExists(db, "deleted_users", "enc_algo")
+ if err != nil {
+ return err
+ }
+
+ if !exists {
+ log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table")
+
+ _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`)
+ if err != nil {
+ return err
+ }
+
+ log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table")
+ }
+
+ return nil
+}
+
+// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using,
+// legacy CBC encryption with a static IV to the new GCM encryption method.
+func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error {
+ log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM")
+
+ tx, err := db.Begin()
+ if err != nil {
+ return fmt.Errorf("failed to begin transaction: %v", err)
+ }
+ defer func() {
+ _ = tx.Rollback()
+ }()
+
+ rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo))
+ if err != nil {
+ return fmt.Errorf("failed to execute select query: %v", err)
+ }
+ defer rows.Close()
+
+ updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`)
+ if err != nil {
+ return fmt.Errorf("failed to prepare update statement: %v", err)
+ }
+ defer updateStmt.Close()
+
+ if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil {
+ return err
+ }
+
+ if err = tx.Commit(); err != nil {
+ return fmt.Errorf("failed to commit transaction: %v", err)
+ }
+
+ log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM")
+ return nil
+}
+
+// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM.
+func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error {
+ for rows.Next() {
+ var (
+ id, decryptedEmail, decryptedName string
+ email, name *string
+ )
+
+ err := rows.Scan(&id, &email, &name)
+ if err != nil {
+ return err
+ }
+
+ if email != nil {
+ decryptedEmail, err = crypt.LegacyDecrypt(*email)
+ if err != nil {
+ log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
+ id,
+ fmt.Errorf("failed to decrypt email: %w", err),
+ )
+ continue
+ }
+ }
+
+ if name != nil {
+ decryptedName, err = crypt.LegacyDecrypt(*name)
+ if err != nil {
+ log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v",
+ id,
+ fmt.Errorf("failed to decrypt name: %w", err),
+ )
+ continue
+ }
+ }
+
+ encryptedEmail, err := crypt.Encrypt(decryptedEmail)
+ if err != nil {
+ return fmt.Errorf("failed to encrypt email: %w", err)
+ }
+
+ encryptedName, err := crypt.Encrypt(decryptedName)
+ if err != nil {
+ return fmt.Errorf("failed to encrypt name: %w", err)
+ }
+
+ _, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id)
+ if err != nil {
+ return err
+ }
+ }
+
+ if err := rows.Err(); err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/management/server/activity/sqlite/migration_test.go b/management/server/activity/sqlite/migration_test.go
new file mode 100644
index 000000000..a03774fa8
--- /dev/null
+++ b/management/server/activity/sqlite/migration_test.go
@@ -0,0 +1,84 @@
+package sqlite
+
+import (
+ "context"
+ "database/sql"
+ "path/filepath"
+ "testing"
+ "time"
+
+ _ "github.com/mattn/go-sqlite3"
+ "github.com/netbirdio/netbird/management/server/activity"
+
+ "github.com/stretchr/testify/require"
+)
+
+func setupDatabase(t *testing.T) *sql.DB {
+ t.Helper()
+
+ dbFile := filepath.Join(t.TempDir(), eventSinkDB)
+ db, err := sql.Open("sqlite3", dbFile)
+ require.NoError(t, err, "Failed to open database")
+
+ t.Cleanup(func() {
+ _ = db.Close()
+ })
+
+ _, err = db.Exec(createTableQuery)
+ require.NoError(t, err, "Failed to create events table")
+
+ _, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`)
+ require.NoError(t, err, "Failed to create deleted_users table")
+
+ return db
+}
+
+func TestMigrate(t *testing.T) {
+ db := setupDatabase(t)
+
+ key, err := GenerateKey()
+ require.NoError(t, err, "Failed to generate key")
+
+ crypt, err := NewFieldEncrypt(key)
+ require.NoError(t, err, "Failed to initialize FieldEncrypt")
+
+ legacyEmail := crypt.LegacyEncrypt("testaccount@test.com")
+ legacyName := crypt.LegacyEncrypt("Test Account")
+
+ _, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`,
+ activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "")
+ require.NoError(t, err, "Failed to insert event")
+
+ _, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName)
+ require.NoError(t, err, "Failed to insert legacy encrypted data")
+
+ colExists, err := checkColumnExists(db, "deleted_users", "enc_algo")
+ require.NoError(t, err, "Failed to check if enc_algo column exists")
+ require.False(t, colExists, "enc_algo column should not exist before migration")
+
+ err = migrate(context.Background(), crypt, db)
+ require.NoError(t, err, "Migration failed")
+
+ colExists, err = checkColumnExists(db, "deleted_users", "enc_algo")
+ require.NoError(t, err, "Failed to check if enc_algo column exists after migration")
+ require.True(t, colExists, "enc_algo column should exist after migration")
+
+ var encAlgo string
+ err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo)
+ require.NoError(t, err, "Failed to select updated data")
+ require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration")
+
+ store, err := createStore(crypt, db)
+ require.NoError(t, err, "Failed to create store")
+
+ events, err := store.Get(context.Background(), "accountID", 0, 1, false)
+ require.NoError(t, err, "Failed to get events")
+
+ require.Len(t, events, 1, "Should have one event")
+ require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match")
+ require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match")
+ require.Equal(t, "targetID", events[0].TargetID, "target id should match")
+ require.Equal(t, "accountID", events[0].AccountID, "account id should match")
+ require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match")
+ require.Equal(t, "Test Account", events[0].Meta["username"], "username should match")
+}
diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go
index fadf1eb07..823e0b4ac 100644
--- a/management/server/activity/sqlite/sqlite.go
+++ b/management/server/activity/sqlite/sqlite.go
@@ -26,7 +26,7 @@ const (
"meta TEXT," +
" target_id TEXT);"
- creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`
+ creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);`
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta
FROM events
@@ -69,10 +69,12 @@ const (
and some selfhosted deployments might have duplicates already so we need to clean the table first.
*/
- insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`
+ insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)`
fallbackName = "unknown"
fallbackEmail = "unknown@unknown.com"
+
+ gcmEncAlgo = "GCM"
)
// Store is the implementation of the activity.Store interface backed by SQLite
@@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) (
return nil, err
}
- _, err = db.Exec(createTableQuery)
- if err != nil {
+ if err = migrate(ctx, crypt, db); err != nil {
_ = db.Close()
- return nil, err
+ return nil, fmt.Errorf("events database migration: %w", err)
}
- _, err = db.Exec(creatTableDeletedUsersQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- err = updateDeletedUsersTable(ctx, db)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- insertStmt, err := db.Prepare(insertQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- selectDescStmt, err := db.Prepare(selectDescQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- selectAscStmt, err := db.Prepare(selectAscQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
- if err != nil {
- _ = db.Close()
- return nil, err
- }
-
- s := &Store{
- db: db,
- fieldEncrypt: crypt,
- insertStatement: insertStmt,
- selectDescStatement: selectDescStmt,
- selectAscStatement: selectAscStmt,
- deleteUserStmt: deleteUserStmt,
- }
-
- return s, nil
+ return createStore(crypt, db)
}
func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) {
@@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event
return event.Meta, nil
}
- encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
- encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
- _, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName)
+ encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email))
+ if err != nil {
+ return nil, err
+ }
+ encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name))
+ if err != nil {
+ return nil, err
+ }
+
+ _, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo)
if err != nil {
return nil, err
}
@@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error {
return nil
}
-func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error {
- log.WithContext(ctx).Debugf("check deleted_users table version")
- rows, err := db.Query(`PRAGMA table_info(deleted_users);`)
+// createStore initializes and returns a new Store instance with prepared SQL statements.
+func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) {
+ insertStmt, err := db.Prepare(insertQuery)
if err != nil {
- return err
+ _ = db.Close()
+ return nil, err
+ }
+
+ selectDescStmt, err := db.Prepare(selectDescQuery)
+ if err != nil {
+ _ = db.Close()
+ return nil, err
+ }
+
+ selectAscStmt, err := db.Prepare(selectAscQuery)
+ if err != nil {
+ _ = db.Close()
+ return nil, err
+ }
+
+ deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
+ if err != nil {
+ _ = db.Close()
+ return nil, err
+ }
+
+ return &Store{
+ db: db,
+ fieldEncrypt: crypt,
+ insertStatement: insertStmt,
+ selectDescStatement: selectDescStmt,
+ selectAscStatement: selectAscStmt,
+ deleteUserStmt: deleteUserStmt,
+ }, nil
+}
+
+// checkColumnExists checks if a column exists in a specified table
+func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) {
+ query := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
+ rows, err := db.Query(query)
+ if err != nil {
+ return false, fmt.Errorf("failed to query table info: %w", err)
}
defer rows.Close()
- found := false
+
for rows.Next() {
- var (
- cid int
- name string
- dataType string
- notNull int
- dfltVal sql.NullString
- pk int
- )
- err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk)
+ var cid int
+ var name, ctype string
+ var notnull, pk int
+ var dfltValue sql.NullString
+
+ err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk)
if err != nil {
- return err
+ return false, fmt.Errorf("failed to scan row: %w", err)
}
- if name == "name" {
- found = true
- break
+
+ if name == columnName {
+ return true, nil
}
}
- err = rows.Err()
- if err != nil {
- return err
+ if err = rows.Err(); err != nil {
+ return false, err
}
- if found {
- return nil
- }
-
- log.WithContext(ctx).Debugf("update delted_users table")
- _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`)
- return err
+ return false, nil
}
diff --git a/management/server/config.go b/management/server/config.go
index 4efe4fe74..2f7e49766 100644
--- a/management/server/config.go
+++ b/management/server/config.go
@@ -34,6 +34,7 @@ const (
type Config struct {
Stuns []*Host
TURNConfig *TURNConfig
+ Relay *Relay
Signal *Host
Datadir string
@@ -75,6 +76,12 @@ type TURNConfig struct {
Turns []*Host
}
+type Relay struct {
+ Addresses []string
+ CredentialsTTL util.Duration
+ Secret string
+}
+
// HttpServerConfig is a config of the HTTP Management service server
type HttpServerConfig struct {
LetsEncryptDomain string
diff --git a/management/server/dns.go b/management/server/dns.go
index fa62a0d4a..256b8b125 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings {
// GetDNSSettings validates a user role and returns the DNS settings for the provided account ID
func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
}
- dnsSettings := account.DNSSettings.Copy()
- return &dnsSettings, nil
+
+ return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
}
// SaveDNSSettings validates a user role and updates the account's DNS settings
diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go
index 36c88f1d1..1390352a5 100644
--- a/management/server/ephemeral_test.go
+++ b/management/server/ephemeral_test.go
@@ -7,6 +7,7 @@ import (
"time"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/status"
)
type MockStore struct {
@@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou
return s.account, nil
}
- return nil, fmt.Errorf("account not found")
+ return nil, status.NewPeerNotFoundError(peerId)
}
type MocAccountManager struct {
diff --git a/management/server/file_store.go b/management/server/file_store.go
index 6e3536bcd..994a4b1ee 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -2,20 +2,23 @@ package server
import (
"context"
+ "errors"
+ "net"
"os"
"path/filepath"
"strings"
"sync"
"time"
- "github.com/rs/xid"
- log "github.com/sirupsen/logrus"
-
+ "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
+ "github.com/netbirdio/netbird/route"
+ "github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
@@ -46,6 +49,158 @@ type FileStore struct {
metrics telemetry.AppMetrics `json:"-"`
}
+func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error {
+ return f(s)
+}
+
+func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)]
+ if !ok {
+ return status.NewSetupKeyNotFoundError()
+ }
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return err
+ }
+
+ account.SetupKeys[setupKeyID].UsedTimes++
+
+ return s.SaveAccount(ctx, account)
+}
+
+func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return err
+ }
+
+ allGroup, err := account.GetGroupAll()
+ if err != nil || allGroup == nil {
+ return errors.New("all group not found")
+ }
+
+ allGroup.Peers = append(allGroup.Peers, peerID)
+
+ return nil
+}
+
+func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountId)
+ if err != nil {
+ return err
+ }
+
+ account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId)
+
+ return nil
+}
+
+func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, ok := s.Accounts[peer.AccountID]
+ if !ok {
+ return status.NewAccountNotFoundError(peer.AccountID)
+ }
+
+ account.Peers[peer.ID] = peer
+ return s.SaveAccount(ctx, account)
+}
+
+func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, ok := s.Accounts[accountId]
+ if !ok {
+ return status.NewAccountNotFoundError(accountId)
+ }
+
+ account.Network.Serial++
+
+ return s.SaveAccount(ctx, account)
+}
+
+func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)]
+ if !ok {
+ return nil, status.NewSetupKeyNotFoundError()
+ }
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ setupKey, ok := account.SetupKeys[key]
+ if !ok {
+ return nil, status.Errorf(status.NotFound, "setup key not found")
+ }
+
+ return setupKey, nil
+}
+
+func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ var takenIps []net.IP
+ for _, existingPeer := range account.Peers {
+ takenIps = append(takenIps, existingPeer.IP)
+ }
+
+ return takenIps, nil
+}
+
+func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ existingLabels := []string{}
+ for _, peer := range account.Peers {
+ if peer.DNSLabel != "" {
+ existingLabels = append(existingLabels, peer.DNSLabel)
+ }
+ }
+ return existingLabels, nil
+}
+
+func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ return account.Network, nil
+}
+
type StoredAccount struct{}
// NewFileStore restores a store from the file located in the datadir
@@ -422,7 +577,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
- return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
+ return nil, status.NewSetupKeyNotFoundError()
}
account, err := s.getAccount(accountID)
@@ -469,6 +624,44 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
return account.Users[userID].Copy(), nil
}
+func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) {
+ accountID, ok := s.UserID2AccountID[userID]
+ if !ok {
+ return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
+ }
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ user := account.Users[userID].Copy()
+ pat := make([]PersonalAccessToken, 0, len(user.PATs))
+ for _, token := range user.PATs {
+ if token != nil {
+ pat = append(pat, *token)
+ }
+ }
+ user.PATsG = pat
+
+ return user, nil
+}
+
+func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) {
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return nil, err
+ }
+
+ groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
+
+ for _, group := range account.Groups {
+ groupsSlice = append(groupsSlice, group)
+ }
+
+ return groupsSlice, nil
+}
+
// GetAllAccounts returns all accounts
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
s.mux.Lock()
@@ -484,7 +677,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
func (s *FileStore) getAccount(accountID string) (*Account, error) {
account, ok := s.Accounts[accountID]
if !ok {
- return nil, status.Errorf(status.NotFound, "account not found")
+ return nil, status.NewAccountNotFoundError(accountID)
}
return account, nil
@@ -610,13 +803,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) (
accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)]
if !ok {
- return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists")
+ return "", status.NewSetupKeyNotFoundError()
}
return accountID, nil
}
-func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) {
+func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -639,7 +832,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp
return nil, status.NewPeerNotFoundError(peerKey)
}
-func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) {
+func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) {
s.mux.Lock()
defer s.mux.Unlock()
@@ -729,7 +922,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.
}
// SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things.
-func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
+func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error {
s.mux.Lock()
defer s.mux.Unlock()
@@ -748,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.T
return nil
}
-func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
+func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) {
return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented")
}
@@ -767,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
-func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error {
+func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
-func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error {
+func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}
+
+func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) {
+ return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented")
+}
+
+func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) {
+ s.mux.Lock()
+ defer s.mux.Unlock()
+
+ account, err := s.getAccount(accountID)
+ if err != nil {
+ return "", "", err
+ }
+
+ return account.Domain, account.DomainCategory, nil
+}
+
+// AccountExists checks whether an account exists by the given ID.
+func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) {
+ _, exists := s.Accounts[id]
+ return exists, nil
+}
+
+func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented")
+}
+
+func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
+ return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented")
+}
+
+func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) {
+ return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented")
+}
+
+func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) {
+ return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
+}
+
+func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) {
+ return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented")
+
+}
+
+func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented")
+}
+
+func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) {
+ return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented")
+}
+
+func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented")
+}
+
+func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) {
+ return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented")
+}
+
+func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented")
+}
+
+func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) {
+ return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented")
+}
+
+func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) {
+ return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented")
+}
+
+func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) {
+ return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented")
+}
diff --git a/management/server/geolocation/database.go b/management/server/geolocation/database.go
index c9b2eafff..21ae93b9d 100644
--- a/management/server/geolocation/database.go
+++ b/management/server/geolocation/database.go
@@ -1,10 +1,9 @@
package geolocation
import (
+ "context"
"encoding/csv"
- "fmt"
"io"
- "net/url"
"os"
"path"
"strconv"
@@ -20,26 +19,27 @@ const (
geoLiteCityZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip"
geoLiteCitySha256TarURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
geoLiteCitySha256ZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256"
+ geoLiteCityMMDB = "GeoLite2-City.mmdb"
+ geoLiteCityCSV = "GeoLite2-City-Locations-en.csv"
)
// loadGeolocationDatabases loads the MaxMind databases.
-func loadGeolocationDatabases(dataDir string) error {
- files := []string{MMDBFileName, GeoSqliteDBFile}
- for _, file := range files {
+func loadGeolocationDatabases(ctx context.Context, dataDir string, mmdbFile string, geonamesdbFile string) error {
+ for _, file := range []string{mmdbFile, geonamesdbFile} {
exists, _ := fileExists(path.Join(dataDir, file))
if exists {
continue
}
- log.Infof("geo location file %s not found , file will be downloaded", file)
+ log.WithContext(ctx).Infof("Geolocation database file %s not found, file will be downloaded", file)
switch file {
- case MMDBFileName:
+ case mmdbFile:
extractFunc := func(src string, dst string) error {
if err := decompressTarGzFile(src, dst); err != nil {
return err
}
- return copyFile(path.Join(dst, MMDBFileName), path.Join(dataDir, MMDBFileName))
+ return copyFile(path.Join(dst, geoLiteCityMMDB), path.Join(dataDir, mmdbFile))
}
if err := loadDatabase(
geoLiteCitySha256TarURL,
@@ -49,13 +49,13 @@ func loadGeolocationDatabases(dataDir string) error {
return err
}
- case GeoSqliteDBFile:
+ case geonamesdbFile:
extractFunc := func(src string, dst string) error {
if err := decompressZipFile(src, dst); err != nil {
return err
}
- extractedCsvFile := path.Join(dst, "GeoLite2-City-Locations-en.csv")
- return importCsvToSqlite(dataDir, extractedCsvFile)
+ extractedCsvFile := path.Join(dst, geoLiteCityCSV)
+ return importCsvToSqlite(dataDir, extractedCsvFile, geonamesdbFile)
}
if err := loadDatabase(
@@ -79,7 +79,12 @@ func loadDatabase(checksumURL string, fileURL string, extractFunc func(src strin
}
defer os.RemoveAll(temp)
- checksumFile := path.Join(temp, getDatabaseFileName(checksumURL))
+ checksumFilename, err := getFilenameFromURL(checksumURL)
+ if err != nil {
+ return err
+ }
+ checksumFile := path.Join(temp, checksumFilename)
+
err = downloadFile(checksumURL, checksumFile)
if err != nil {
return err
@@ -90,7 +95,12 @@ func loadDatabase(checksumURL string, fileURL string, extractFunc func(src strin
return err
}
- dbFile := path.Join(temp, getDatabaseFileName(fileURL))
+ dbFilename, err := getFilenameFromURL(fileURL)
+ if err != nil {
+ return err
+ }
+ dbFile := path.Join(temp, dbFilename)
+
err = downloadFile(fileURL, dbFile)
if err != nil {
return err
@@ -104,13 +114,13 @@ func loadDatabase(checksumURL string, fileURL string, extractFunc func(src strin
}
// importCsvToSqlite imports a CSV file into a SQLite database.
-func importCsvToSqlite(dataDir string, csvFile string) error {
+func importCsvToSqlite(dataDir string, csvFile string, geonamesdbFile string) error {
geonames, err := loadGeonamesCsv(csvFile)
if err != nil {
return err
}
- db, err := gorm.Open(sqlite.Open(path.Join(dataDir, GeoSqliteDBFile)), &gorm.Config{
+ db, err := gorm.Open(sqlite.Open(path.Join(dataDir, geonamesdbFile)), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 1000,
PrepareStmt: true,
@@ -178,18 +188,6 @@ func loadGeonamesCsv(filepath string) ([]GeoNames, error) {
return geoNames, nil
}
-// getDatabaseFileName extracts the file name from a given URL string.
-func getDatabaseFileName(urlStr string) string {
- u, err := url.Parse(urlStr)
- if err != nil {
- panic(err)
- }
-
- ext := u.Query().Get("suffix")
- fileName := fmt.Sprintf("%s.%s", path.Base(u.Path), ext)
- return fileName
-}
-
// copyFile performs a file copy operation from the source file to the destination.
func copyFile(src string, dst string) error {
srcFile, err := os.Open(src)
diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go
index 794f9d0be..553a31581 100644
--- a/management/server/geolocation/geolocation.go
+++ b/management/server/geolocation/geolocation.go
@@ -1,29 +1,25 @@
package geolocation
import (
- "bytes"
"context"
"fmt"
"net"
"os"
"path"
+ "path/filepath"
+ "strings"
"sync"
- "time"
"github.com/oschwald/maxminddb-golang"
log "github.com/sirupsen/logrus"
)
-const MMDBFileName = "GeoLite2-City.mmdb"
-
type Geolocation struct {
- mmdbPath string
- mux sync.RWMutex
- sha256sum []byte
- db *maxminddb.Reader
- locationDB *SqliteStore
- stopCh chan struct{}
- reloadCheckInterval time.Duration
+ mmdbPath string
+ mux sync.RWMutex
+ db *maxminddb.Reader
+ locationDB *SqliteStore
+ stopCh chan struct{}
}
type Record struct {
@@ -53,45 +49,56 @@ type Country struct {
CountryName string
}
-func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) {
- if err := loadGeolocationDatabases(dataDir); err != nil {
+const (
+ mmdbPattern = "GeoLite2-City_*.mmdb"
+ geonamesdbPattern = "geonames_*.db"
+)
+
+func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) {
+ mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
+ mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get database filename: %v", err)
+ }
+
+ geonamesDbGlobPattern := filepath.Join(dataDir, geonamesdbPattern)
+ geonamesDbFile, err := getDatabaseFilename(ctx, geoLiteCityZipURL, geonamesDbGlobPattern, autoUpdate)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get database filename: %v", err)
+ }
+
+ if err := loadGeolocationDatabases(ctx, dataDir, mmdbFile, geonamesDbFile); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
}
- mmdbPath := path.Join(dataDir, MMDBFileName)
+ if err := cleanupMaxMindDatabases(ctx, dataDir, mmdbFile, geonamesDbFile); err != nil {
+ return nil, fmt.Errorf("failed to remove old MaxMind databases: %v", err)
+ }
+
+ mmdbPath := path.Join(dataDir, mmdbFile)
db, err := openDB(mmdbPath)
if err != nil {
return nil, err
}
- sha256sum, err := calculateFileSHA256(mmdbPath)
- if err != nil {
- return nil, err
- }
-
- locationDB, err := NewSqliteStore(ctx, dataDir)
+ locationDB, err := NewSqliteStore(ctx, dataDir, geonamesDbFile)
if err != nil {
return nil, err
}
geo := &Geolocation{
- mmdbPath: mmdbPath,
- mux: sync.RWMutex{},
- sha256sum: sha256sum,
- db: db,
- locationDB: locationDB,
- reloadCheckInterval: 300 * time.Second, // TODO: make configurable
- stopCh: make(chan struct{}),
+ mmdbPath: mmdbPath,
+ mux: sync.RWMutex{},
+ db: db,
+ locationDB: locationDB,
+ stopCh: make(chan struct{}),
}
- go geo.reloader(ctx)
-
return geo, nil
}
func openDB(mmdbPath string) (*maxminddb.Reader, error) {
_, err := os.Stat(mmdbPath)
-
if os.IsNotExist(err) {
return nil, fmt.Errorf("%v does not exist", mmdbPath)
} else if err != nil {
@@ -166,70 +173,6 @@ func (gl *Geolocation) Stop() error {
return nil
}
-func (gl *Geolocation) reloader(ctx context.Context) {
- for {
- select {
- case <-gl.stopCh:
- return
- case <-time.After(gl.reloadCheckInterval):
- if err := gl.locationDB.reload(ctx); err != nil {
- log.WithContext(ctx).Errorf("geonames db reload failed: %s", err)
- }
-
- newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
- continue
- }
- if !bytes.Equal(gl.sha256sum, newSha256sum1) {
- // we check sum twice just to avoid possible case when we reload during update of the file
- // considering the frequency of file update (few times a week) checking sum twice should be enough
- time.Sleep(50 * time.Millisecond)
- newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
- continue
- }
- if !bytes.Equal(newSha256sum1, newSha256sum2) {
- log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath)
- continue
- }
- err = gl.reload(ctx, newSha256sum2)
- if err != nil {
- log.WithContext(ctx).Errorf("mmdb reload failed: %s", err)
- }
- } else {
- log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.",
- gl.mmdbPath, gl.reloadCheckInterval.Seconds())
- }
- }
- }
-}
-
-func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error {
- gl.mux.Lock()
- defer gl.mux.Unlock()
-
- log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath)
-
- err := gl.db.Close()
- if err != nil {
- return err
- }
-
- db, err := openDB(gl.mmdbPath)
- if err != nil {
- return err
- }
-
- gl.db = db
- gl.sha256sum = newSha256sum
-
- log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath)
-
- return nil
-}
-
func fileExists(filePath string) (bool, error) {
_, err := os.Stat(filePath)
if err == nil {
@@ -240,3 +183,79 @@ func fileExists(filePath string) (bool, error) {
}
return false, err
}
+
+func getExistingDatabases(pattern string) []string {
+ files, _ := filepath.Glob(pattern)
+ return files
+}
+
+func getDatabaseFilename(ctx context.Context, databaseURL string, filenamePattern string, autoUpdate bool) (string, error) {
+ var (
+ filename string
+ err error
+ )
+
+ if autoUpdate {
+ filename, err = getFilenameFromURL(databaseURL)
+ if err != nil {
+ log.WithContext(ctx).Debugf("Failed to update database from url: %s", databaseURL)
+ return "", err
+ }
+ } else {
+ files := getExistingDatabases(filenamePattern)
+ if len(files) < 1 {
+ filename, err = getFilenameFromURL(databaseURL)
+ if err != nil {
+ log.WithContext(ctx).Debugf("Failed to get database from url: %s", databaseURL)
+ return "", err
+ }
+ } else {
+ filename = filepath.Base(files[len(files)-1])
+ log.WithContext(ctx).Debugf("Using existing database, %s", filename)
+ return filename, nil
+ }
+ }
+
+ // strip suffixes that may be nested, such as .tar.gz
+ basename := strings.SplitN(filename, ".", 2)[0]
+ // get date version from basename
+ date := strings.SplitN(basename, "_", 2)[1]
+ // format db as "GeoLite2-Cities-{maxmind|geonames}_{DATE}.{mmdb|db}"
+ databaseFilename := filepath.Base(strings.Replace(filenamePattern, "*", date, 1))
+
+ return databaseFilename, nil
+}
+
+func cleanupOldDatabases(ctx context.Context, pattern string, currentFile string) error {
+ files := getExistingDatabases(pattern)
+
+ for _, db := range files {
+ if filepath.Base(db) == currentFile {
+ continue
+ }
+ log.WithContext(ctx).Debugf("Removing old database: %s", db)
+ err := os.Remove(db)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile string, geonamesdbFile string) error {
+ for _, file := range []string{mmdbFile, geonamesdbFile} {
+ switch file {
+ case mmdbFile:
+ pattern := filepath.Join(dataDir, mmdbPattern)
+ if err := cleanupOldDatabases(ctx, pattern, file); err != nil {
+ return err
+ }
+ case geonamesdbFile:
+ pattern := filepath.Join(dataDir, geonamesdbPattern)
+ if err := cleanupOldDatabases(ctx, pattern, file); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
diff --git a/management/server/geolocation/geolocation_test.go b/management/server/geolocation/geolocation_test.go
index 6fd46fcfe..9bdefd268 100644
--- a/management/server/geolocation/geolocation_test.go
+++ b/management/server/geolocation/geolocation_test.go
@@ -2,8 +2,8 @@ package geolocation
import (
"net"
- "os"
"path"
+ "path/filepath"
"sync"
"testing"
@@ -13,21 +13,15 @@ import (
)
// from https://github.com/maxmind/MaxMind-DB/blob/main/test-data/GeoLite2-City-Test.mmdb
-var mmdbPath = "../testdata/GeoLite2-City-Test.mmdb"
+var mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb"
func TestGeoLite_Lookup(t *testing.T) {
tempDir := t.TempDir()
- filename := path.Join(tempDir, MMDBFileName)
+ filename := path.Join(tempDir, filepath.Base(mmdbPath))
err := util.CopyFileContents(mmdbPath, filename)
assert.NoError(t, err)
- defer func() {
- err := os.Remove(filename)
- if err != nil {
- t.Errorf("os.Remove: %s", err)
- }
- }()
- db, err := openDB(mmdbPath)
+ db, err := openDB(filename)
assert.NoError(t, err)
geo := &Geolocation{
diff --git a/management/server/geolocation/store.go b/management/server/geolocation/store.go
index 67d420cfd..1f94bf47e 100644
--- a/management/server/geolocation/store.go
+++ b/management/server/geolocation/store.go
@@ -1,7 +1,6 @@
package geolocation
import (
- "bytes"
"context"
"fmt"
"path/filepath"
@@ -17,10 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
-const (
- GeoSqliteDBFile = "geonames.db"
-)
-
type GeoNames struct {
GeoNameID int `gorm:"column:geoname_id"`
LocaleCode string `gorm:"column:locale_code"`
@@ -44,31 +39,24 @@ func (*GeoNames) TableName() string {
// SqliteStore represents a location storage backed by a Sqlite DB.
type SqliteStore struct {
- db *gorm.DB
- filePath string
- mux sync.RWMutex
- closed bool
- sha256sum []byte
+ db *gorm.DB
+ filePath string
+ mux sync.RWMutex
+ closed bool
}
-func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) {
- file := filepath.Join(dataDir, GeoSqliteDBFile)
+func NewSqliteStore(ctx context.Context, dataDir string, dbPath string) (*SqliteStore, error) {
+ file := filepath.Join(dataDir, dbPath)
db, err := connectDB(ctx, file)
if err != nil {
return nil, err
}
- sha256sum, err := calculateFileSHA256(file)
- if err != nil {
- return nil, err
- }
-
return &SqliteStore{
- db: db,
- filePath: file,
- mux: sync.RWMutex{},
- sha256sum: sha256sum,
+ db: db,
+ filePath: file,
+ mux: sync.RWMutex{},
}, nil
}
@@ -115,48 +103,6 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
return cities, nil
}
-// reload attempts to reload the SqliteStore's database if the database file has changed.
-func (s *SqliteStore) reload(ctx context.Context) error {
- s.mux.Lock()
- defer s.mux.Unlock()
-
- newSha256sum1, err := calculateFileSHA256(s.filePath)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
- }
-
- if !bytes.Equal(s.sha256sum, newSha256sum1) {
- // we check sum twice just to avoid possible case when we reload during update of the file
- // considering the frequency of file update (few times a week) checking sum twice should be enough
- time.Sleep(50 * time.Millisecond)
- newSha256sum2, err := calculateFileSHA256(s.filePath)
- if err != nil {
- return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
- }
- if !bytes.Equal(newSha256sum1, newSha256sum2) {
- return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath)
- }
-
- log.WithContext(ctx).Infof("Reloading '%s'", s.filePath)
- _ = s.close()
- s.closed = true
-
- newDb, err := connectDB(ctx, s.filePath)
- if err != nil {
- return err
- }
-
- s.closed = false
- s.db = newDb
-
- log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath)
- } else {
- log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath)
- }
-
- return nil
-}
-
// close closes the database connection.
// It retrieves the underlying *sql.DB object from the *gorm.DB object
// and calls the Close() method on it.
diff --git a/management/server/geolocation/utils.go b/management/server/geolocation/utils.go
index bdbd4732d..5104b0a08 100644
--- a/management/server/geolocation/utils.go
+++ b/management/server/geolocation/utils.go
@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
+ "mime"
"net/http"
"os"
"path"
@@ -174,3 +175,21 @@ func downloadFile(url, filepath string) error {
_, err = io.Copy(out, bytes.NewBuffer(bodyBytes))
return err
}
+
+func getFilenameFromURL(url string) (string, error) {
+ resp, err := http.Head(url)
+ if err != nil {
+ return "", err
+ }
+
+ defer resp.Body.Close()
+
+ _, params, err := mime.ParseMediaType(resp.Header["Content-Disposition"][0])
+ if err != nil {
+ return "", err
+ }
+
+ filename := params["filename"]
+
+ return filename, nil
+}
diff --git a/management/server/group.go b/management/server/group.go
index 63281a2f1..91c06a3c0 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
}
-// GetGroup object of the peers
+// CheckGroupPermissions validates if a user has the necessary permissions to view groups
+func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
+
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return err
+ }
+
+ if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
+ return status.Errorf(status.PermissionDenied, "groups are blocked for users")
+ }
+
+ return nil
+}
+
+// GetGroup returns a specific group by groupID in an account
func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
+ if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
- return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
- }
-
- group, ok := account.Groups[groupID]
- if ok {
- return group, nil
- }
-
- return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID)
+ return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
}
// GetAllGroups returns all groups in an account
-func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
+func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
+ if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
- return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users")
- }
-
- groups := make([]*nbgroup.Group, 0, len(account.Groups))
- for _, item := range account.Groups {
- groups = append(groups, item)
- }
-
- return groups, nil
+ return am.Store.GetAccountGroups(ctx, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return nil, err
- }
-
- matchingGroups := make([]*nbgroup.Group, 0)
- for _, group := range account.Groups {
- if group.Name == groupName {
- matchingGroups = append(matchingGroups, group)
- }
- }
-
- if len(matchingGroups) == 0 {
- return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName)
- }
-
- maxPeers := -1
- var groupWithMostPeers *nbgroup.Group
- for i, group := range matchingGroups {
- if len(group.Peers) > maxPeers {
- maxPeers = len(group.Peers)
- groupWithMostPeers = matchingGroups[i]
- }
- }
-
- return groupWithMostPeers, nil
+ return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
}
// SaveGroup object of the peers
@@ -269,6 +224,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
return nil
}
+ allGroup, err := account.GetGroupAll()
+ if err != nil {
+ return err
+ }
+
+ if allGroup.ID == groupID {
+ return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
+ }
+
if err = validateDeleteGroup(account, group, userId); err != nil {
return err
}
diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go
index ff7a71cfd..4c4ef6c3c 100644
--- a/management/server/grpcserver.go
+++ b/management/server/grpcserver.go
@@ -16,13 +16,12 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
- nbContext "github.com/netbirdio/netbird/management/server/context"
- "github.com/netbirdio/netbird/management/server/posture"
-
"github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/management/proto"
+ nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "github.com/netbirdio/netbird/management/server/posture"
internalStatus "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
)
@@ -32,17 +31,25 @@ type GRPCServer struct {
accountManager AccountManager
wgKey wgtypes.Key
proto.UnimplementedManagementServiceServer
- peersUpdateManager *PeersUpdateManager
- config *Config
- turnCredentialsManager TURNCredentialsManager
- jwtValidator *jwtclaims.JWTValidator
- jwtClaimsExtractor *jwtclaims.ClaimsExtractor
- appMetrics telemetry.AppMetrics
- ephemeralManager *EphemeralManager
+ peersUpdateManager *PeersUpdateManager
+ config *Config
+ secretsManager SecretsManager
+ jwtValidator *jwtclaims.JWTValidator
+ jwtClaimsExtractor *jwtclaims.ClaimsExtractor
+ appMetrics telemetry.AppMetrics
+ ephemeralManager *EphemeralManager
}
// NewServer creates a new Management server
-func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) {
+func NewServer(
+ ctx context.Context,
+ config *Config,
+ accountManager AccountManager,
+ peersUpdateManager *PeersUpdateManager,
+ secretsManager SecretsManager,
+ appMetrics telemetry.AppMetrics,
+ ephemeralManager *EphemeralManager,
+) (*GRPCServer, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
@@ -88,14 +95,14 @@ func NewServer(ctx context.Context, config *Config, accountManager AccountManage
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
- peersUpdateManager: peersUpdateManager,
- accountManager: accountManager,
- config: config,
- turnCredentialsManager: turnCredentialsManager,
- jwtValidator: jwtValidator,
- jwtClaimsExtractor: jwtClaimsExtractor,
- appMetrics: appMetrics,
- ephemeralManager: ephemeralManager,
+ peersUpdateManager: peersUpdateManager,
+ accountManager: accountManager,
+ config: config,
+ secretsManager: secretsManager,
+ jwtValidator: jwtValidator,
+ jwtClaimsExtractor: jwtClaimsExtractor,
+ appMetrics: appMetrics,
+ ephemeralManager: ephemeralManager,
}, nil
}
@@ -132,24 +139,30 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
ctx := srv.Context()
- realIP := getRealIP(ctx)
-
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
return err
}
- //nolint
+ // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
+
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
- // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail
- accountID = "UNKNOWN"
+ // nolint:staticcheck
+ ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
+ log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
+ if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
+ return status.Errorf(codes.PermissionDenied, "peer is not registered")
+ }
+ return err
}
- //nolint
+
+ // nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
+ realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
if syncReq.GetMeta() == nil {
@@ -171,9 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
s.ephemeralManager.OnPeerConnected(ctx, peer)
- if s.config.TURNConfig.TimeBasedCredentials {
- s.turnCredentialsManager.SetupRefresh(ctx, peer.ID)
- }
+ s.secretsManager.SetupRefresh(ctx, peer.ID)
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart))
@@ -235,7 +246,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
- s.turnCredentialsManager.CancelRefresh(peer.ID)
+ s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
@@ -251,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
}
claims := s.jwtClaimsExtractor.FromToken(token)
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
- _, _, err = s.accountManager.GetAccountFromToken(ctx, claims)
+ _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims)
if err != nil {
return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err)
}
@@ -421,9 +432,17 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
+ var relayToken *Token
+ if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
+ relayToken, err = s.secretsManager.GenerateRelayToken()
+ if err != nil {
+ log.Errorf("failed generating Relay token: %v", err)
+ }
+ }
+
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
- WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
+ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()),
Checks: toProtocolChecks(ctx, postureChecks),
}
@@ -481,10 +500,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol {
}
}
-func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig {
+func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Token) *proto.WiretrusteeConfig {
if config == nil {
return nil
}
+
var stuns []*proto.HostConfig
for _, stun := range config.Stuns {
stuns = append(stuns, &proto.HostConfig{
@@ -492,25 +512,40 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Protocol: ToResponseProto(stun.Proto),
})
}
+
var turns []*proto.ProtectedHostConfig
- for _, turn := range config.TURNConfig.Turns {
- var username string
- var password string
- if turnCredentials != nil {
- username = turnCredentials.Username
- password = turnCredentials.Password
- } else {
- username = turn.Username
- password = turn.Password
+ if config.TURNConfig != nil {
+ for _, turn := range config.TURNConfig.Turns {
+ var username string
+ var password string
+ if turnCredentials != nil {
+ username = turnCredentials.Payload
+ password = turnCredentials.Signature
+ } else {
+ username = turn.Username
+ password = turn.Password
+ }
+ turns = append(turns, &proto.ProtectedHostConfig{
+ HostConfig: &proto.HostConfig{
+ Uri: turn.URI,
+ Protocol: ToResponseProto(turn.Proto),
+ },
+ User: username,
+ Password: password,
+ })
+ }
+ }
+
+ var relayCfg *proto.RelayConfig
+ if config.Relay != nil && len(config.Relay.Addresses) > 0 {
+ relayCfg = &proto.RelayConfig{
+ Urls: config.Relay.Addresses,
+ }
+
+ if relayToken != nil {
+ relayCfg.TokenPayload = relayToken.Payload
+ relayCfg.TokenSignature = relayToken.Signature
}
- turns = append(turns, &proto.ProtectedHostConfig{
- HostConfig: &proto.HostConfig{
- Uri: turn.URI,
- Protocol: ToResponseProto(turn.Proto),
- },
- User: username,
- Password: password,
- })
}
return &proto.WiretrusteeConfig{
@@ -520,6 +555,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
Uri: config.Signal.URI,
Protocol: ToResponseProto(config.Signal.Proto),
},
+ Relay: relayCfg,
}
}
@@ -533,9 +569,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
}
}
-func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
+func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
response := &proto.SyncResponse{
- WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
+ WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
@@ -560,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
+ routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
+ response.NetworkMap.RoutesFirewallRules = routesFirewallRules
+ response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
+
return response
}
@@ -582,15 +622,25 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em
// sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization
func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error {
- // make secret time based TURN credentials optional
- var turnCredentials *TURNCredentials
- if s.config.TURNConfig.TimeBasedCredentials {
- creds := s.turnCredentialsManager.GenerateCredentials()
- turnCredentials = &creds
- } else {
- turnCredentials = nil
+ var err error
+
+ var turnToken *Token
+ if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials {
+ turnToken, err = s.secretsManager.GenerateTurnToken()
+ if err != nil {
+ log.Errorf("failed generating TURN token: %v", err)
+ }
}
- plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
+
+ var relayToken *Token
+ if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
+ relayToken, err = s.secretsManager.GenerateRelayToken()
+ if err != nil {
+ log.Errorf("failed generating Relay token: %v", err)
+ }
+ }
+
+ plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {
diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go
index ffa5b9a28..91caa1512 100644
--- a/management/server/http/accounts_handler.go
+++ b/management/server/http/accounts_handler.go
@@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w)
+ settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
- resp := toAccountResponse(account)
+ resp := toAccountResponse(accountID, settings)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request)
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
- updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings)
+ updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- resp := toAccountResponse(updatedAccount)
+ resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings)
util.WriteJSONObject(r.Context(), w, &resp)
}
// DeleteAccount is a HTTP DELETE handler to delete an account
func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodDelete {
- util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
- return
- }
-
claims := h.claimsExtractor.FromRequestContext(r)
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
@@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, emptyObject{})
}
-func toAccountResponse(account *server.Account) *api.Account {
- jwtAllowGroups := account.Settings.JWTAllowGroups
+func toAccountResponse(accountID string, settings *server.Settings) *api.Account {
+ jwtAllowGroups := settings.JWTAllowGroups
if jwtAllowGroups == nil {
jwtAllowGroups = []string{}
}
- settings := api.AccountSettings{
- PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()),
- PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled,
- GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled,
- JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled,
- JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName,
+ apiSettings := api.AccountSettings{
+ PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()),
+ PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled,
+ GroupsPropagationEnabled: &settings.GroupsPropagationEnabled,
+ JwtGroupsEnabled: &settings.JWTGroupsEnabled,
+ JwtGroupsClaimName: &settings.JWTGroupsClaimName,
JwtAllowGroups: &jwtAllowGroups,
- RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked,
+ RegularUsersViewBlocked: settings.RegularUsersViewBlocked,
}
- if account.Settings.Extra != nil {
- settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled}
+ if settings.Extra != nil {
+ apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled}
}
return &api.Account{
- Id: account.Id,
- Settings: settings,
+ Id: accountID,
+ Settings: apiSettings,
}
}
diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go
index 45c7679e5..cacb3d430 100644
--- a/management/server/http/accounts_handler_test.go
+++ b/management/server/http/accounts_handler_test.go
@@ -23,8 +23,11 @@ import (
func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler {
return &AccountsHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return account, admin, nil
+ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return account.Id, admin.Id, nil
+ },
+ GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
+ return account.Settings, nil
},
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) {
halfYearLimit := 180 * 24 * time.Hour
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 45887dc2e..fd0343e97 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -251,7 +251,7 @@ components:
- name
- ssh_enabled
- login_expiration_enabled
- PeerBase:
+ Peer:
allOf:
- $ref: '#/components/schemas/PeerMinimum'
- type: object
@@ -378,25 +378,40 @@ components:
description: User ID of the user that enrolled this peer
type: string
example: google-oauth2|277474792786460067937
+ os:
+ description: Peer's operating system and version
+ type: string
+ example: linux
+ country_code:
+ $ref: '#/components/schemas/CountryCode'
+ city_name:
+ $ref: '#/components/schemas/CityName'
+ geoname_id:
+ description: Unique identifier from the GeoNames database for a specific geographical location.
+ type: integer
+ example: 2643743
+ connected:
+ description: Peer to Management connection status
+ type: boolean
+ example: true
+ last_seen:
+ description: Last time peer connected to Netbird's management service
+ type: string
+ format: date-time
+ example: "2023-05-05T10:05:26.420578Z"
required:
- ip
- dns_label
- user_id
- Peer:
- allOf:
- - $ref: '#/components/schemas/PeerBase'
- - type: object
- properties:
- accessible_peers:
- description: List of accessible peers
- type: array
- items:
- $ref: '#/components/schemas/AccessiblePeer'
- required:
- - accessible_peers
+ - os
+ - country_code
+ - city_name
+ - geoname_id
+ - connected
+ - last_seen
PeerBatch:
allOf:
- - $ref: '#/components/schemas/PeerBase'
+ - $ref: '#/components/schemas/Peer'
- type: object
properties:
accessible_peers_count:
@@ -712,17 +727,39 @@ components:
enum: ["all", "tcp", "udp", "icmp"]
example: "tcp"
ports:
- description: Policy rule affected ports or it ranges list
+ description: Policy rule affected ports
type: array
items:
type: string
example: "80"
+ port_ranges:
+ description: Policy rule affected ports ranges list
+ type: array
+ items:
+ $ref: '#/components/schemas/RulePortRange'
required:
- name
- enabled
- bidirectional
- protocol
- action
+
+ RulePortRange:
+ description: Policy rule affected ports range
+ type: object
+ properties:
+ start:
+ description: The starting port of the range
+ type: integer
+ example: 80
+ end:
+ description: The ending port of the range
+ type: integer
+ example: 320
+ required:
+ - start
+ - end
+
PolicyRuleUpdate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
@@ -935,7 +972,7 @@ components:
type: array
items:
type: string
- example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
+ example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action:
description: Action to take upon policy match
type: string
@@ -1064,12 +1101,12 @@ components:
type: string
example: 10.64.0.0/24
domains:
- description: Domain list to be dynamically resolved. Conflicts with network
+ description: Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network
type: array
items:
type: string
minLength: 1
- maxLength: 255
+ maxLength: 32
example: "example.com"
metric:
description: Route metric number. Lowest number has higher priority
@@ -1091,6 +1128,12 @@ components:
description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore
type: boolean
example: true
+ access_control_groups:
+ description: Access control group identifier associated with route.
+ type: array
+ items:
+ type: string
+ example: "chacbco6lnnbn6cg5s91"
required:
- id
- description
@@ -1806,6 +1849,38 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
+ /api/peers/{peerId}/accessible-peers:
+ get:
+ summary: List accessible Peers
+ description: Returns a list of peers that the specified peer can connect to within the network.
+ tags: [ Peers ]
+ security:
+ - BearerAuth: [ ]
+ - TokenAuth: [ ]
+ parameters:
+ - in: path
+ name: peerId
+ required: true
+ schema:
+ type: string
+ description: The unique identifier of a peer
+ responses:
+ '200':
+ description: A JSON Array of Accessible Peers
+ content:
+ application/json:
+ schema:
+ type: array
+ items:
+ $ref: '#/components/schemas/AccessiblePeer'
+ '400':
+ "$ref": "#/components/responses/bad_request"
+ '401':
+ "$ref": "#/components/responses/requires_authentication"
+ '403':
+ "$ref": "#/components/responses/forbidden"
+ '500':
+ "$ref": "#/components/responses/internal_error"
/api/setup-keys:
get:
summary: List all Setup Keys
@@ -2759,4 +2834,4 @@ paths:
'403':
"$ref": "#/components/responses/forbidden"
'500':
- "$ref": "#/components/responses/internal_error"
\ No newline at end of file
+ "$ref": "#/components/responses/internal_error"
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index 77a6c643d..570ec03c5 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -152,18 +152,36 @@ const (
// AccessiblePeer defines model for AccessiblePeer.
type AccessiblePeer struct {
+ // CityName Commonly used English name of the city
+ CityName CityName `json:"city_name"`
+
+ // Connected Peer to Management connection status
+ Connected bool `json:"connected"`
+
+ // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
+ CountryCode CountryCode `json:"country_code"`
+
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"`
+ // GeonameId Unique identifier from the GeoNames database for a specific geographical location.
+ GeonameId int `json:"geoname_id"`
+
// Id Peer ID
Id string `json:"id"`
// Ip Peer's IP address
Ip string `json:"ip"`
+ // LastSeen Last time peer connected to Netbird's management service
+ LastSeen time.Time `json:"last_seen"`
+
// Name Peer's hostname
Name string `json:"name"`
+ // Os Peer's operating system and version
+ Os string `json:"os"`
+
// UserId User ID of the user that enrolled this peer
UserId string `json:"user_id"`
}
@@ -490,81 +508,6 @@ type OSVersionCheck struct {
// Peer defines model for Peer.
type Peer struct {
- // AccessiblePeers List of accessible peers
- AccessiblePeers []AccessiblePeer `json:"accessible_peers"`
-
- // ApprovalRequired (Cloud only) Indicates whether peer needs approval
- ApprovalRequired bool `json:"approval_required"`
-
- // CityName Commonly used English name of the city
- CityName CityName `json:"city_name"`
-
- // Connected Peer to Management connection status
- Connected bool `json:"connected"`
-
- // ConnectionIp Peer's public connection IP address
- ConnectionIp string `json:"connection_ip"`
-
- // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country
- CountryCode CountryCode `json:"country_code"`
-
- // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
- DnsLabel string `json:"dns_label"`
-
- // GeonameId Unique identifier from the GeoNames database for a specific geographical location.
- GeonameId int `json:"geoname_id"`
-
- // Groups Groups that the peer belongs to
- Groups []GroupMinimum `json:"groups"`
-
- // Hostname Hostname of the machine
- Hostname string `json:"hostname"`
-
- // Id Peer ID
- Id string `json:"id"`
-
- // Ip Peer's IP address
- Ip string `json:"ip"`
-
- // KernelVersion Peer's operating system kernel version
- KernelVersion string `json:"kernel_version"`
-
- // LastLogin Last time this peer performed log in (authentication). E.g., user authenticated.
- LastLogin time.Time `json:"last_login"`
-
- // LastSeen Last time peer connected to Netbird's management service
- LastSeen time.Time `json:"last_seen"`
-
- // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not
- LoginExpirationEnabled bool `json:"login_expiration_enabled"`
-
- // LoginExpired Indicates whether peer's login expired or not
- LoginExpired bool `json:"login_expired"`
-
- // Name Peer's hostname
- Name string `json:"name"`
-
- // Os Peer's operating system and version
- Os string `json:"os"`
-
- // SerialNumber System serial number
- SerialNumber string `json:"serial_number"`
-
- // SshEnabled Indicates whether SSH server is enabled on this peer
- SshEnabled bool `json:"ssh_enabled"`
-
- // UiVersion Peer's desktop UI version
- UiVersion string `json:"ui_version"`
-
- // UserId User ID of the user that enrolled this peer
- UserId string `json:"user_id"`
-
- // Version Peer's daemon or cli version
- Version string `json:"version"`
-}
-
-// PeerBase defines model for PeerBase.
-type PeerBase struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval
ApprovalRequired bool `json:"approval_required"`
@@ -837,7 +780,10 @@ type PolicyRule struct {
// Name Policy rule name identifier
Name string `json:"name"`
- // Ports Policy rule affected ports or it ranges list
+ // PortRanges Policy rule affected ports ranges list
+ PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
+
+ // Ports Policy rule affected ports
Ports *[]string `json:"ports,omitempty"`
// Protocol Policy rule type of the traffic
@@ -873,7 +819,10 @@ type PolicyRuleMinimum struct {
// Name Policy rule name identifier
Name string `json:"name"`
- // Ports Policy rule affected ports or it ranges list
+ // PortRanges Policy rule affected ports ranges list
+ PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
+
+ // Ports Policy rule affected ports
Ports *[]string `json:"ports,omitempty"`
// Protocol Policy rule type of the traffic
@@ -909,7 +858,10 @@ type PolicyRuleUpdate struct {
// Name Policy rule name identifier
Name string `json:"name"`
- // Ports Policy rule affected ports or it ranges list
+ // PortRanges Policy rule affected ports ranges list
+ PortRanges *[]RulePortRange `json:"port_ranges,omitempty"`
+
+ // Ports Policy rule affected ports
Ports *[]string `json:"ports,omitempty"`
// Protocol Policy rule type of the traffic
@@ -992,10 +944,13 @@ type ProcessCheck struct {
// Route defines model for Route.
type Route struct {
+ // AccessControlGroups Access control group identifier associated with route.
+ AccessControlGroups *[]string `json:"access_control_groups,omitempty"`
+
// Description Route description
Description string `json:"description"`
- // Domains Domain list to be dynamically resolved. Conflicts with network
+ // Domains Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network
Domains *[]string `json:"domains,omitempty"`
// Enabled Route status
@@ -1034,10 +989,13 @@ type Route struct {
// RouteRequest defines model for RouteRequest.
type RouteRequest struct {
+ // AccessControlGroups Access control group identifier associated with route.
+ AccessControlGroups *[]string `json:"access_control_groups,omitempty"`
+
// Description Route description
Description string `json:"description"`
- // Domains Domain list to be dynamically resolved. Conflicts with network
+ // Domains Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network
Domains *[]string `json:"domains,omitempty"`
// Enabled Route status
@@ -1068,6 +1026,15 @@ type RouteRequest struct {
PeerGroups *[]string `json:"peer_groups,omitempty"`
}
+// RulePortRange Policy rule affected ports range
+type RulePortRange struct {
+ // End The ending port of the range
+ End int `json:"end"`
+
+ // Start The starting port of the range
+ Start int `json:"start"`
+}
+
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go
index 74b0e1a55..13c2101a7 100644
--- a/management/server/http/dns_settings_handler.go
+++ b/management/server/http/dns_settings_handler.go
@@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetDNSSettings returns the DNS settings for the account
func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id)
+ dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque
// UpdateDNSSettings handles update to DNS settings of an account
func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
- err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings)
+ err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go
index 897ae63dc..8baea7b15 100644
--- a/management/server/http/dns_settings_handler_test.go
+++ b/management/server/http/dns_settings_handler_test.go
@@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler {
}
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
},
- GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil
+ GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
+ return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go
index 428b4c164..ee0c63f28 100644
--- a/management/server/http/events_handler.go
+++ b/management/server/http/events_handler.go
@@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev
// GetAllEvents list of the given account
func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id)
+ accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
events[i] = toEventResponse(e)
}
- err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id)
+ err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go
index 8bdd508bf..e525cf2ee 100644
--- a/management/server/http/events_handler_test.go
+++ b/management/server/http/events_handler_test.go
@@ -20,7 +20,7 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
-func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler {
+func initEventsTestData(account string, events ...*activity.Event) *EventsHandler {
return &EventsHandler{
accountManager: &mock_server.MockAccountManager{
GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) {
@@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
}
return []*activity.Event{}, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return &server.Account{
- Id: claims.AccountId,
- Domain: "hotmail.com",
- Users: map[string]*server.User{
- user.Id: user,
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
return make([]*server.UserInfo, 0), nil
@@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) {
accountID := "test_account"
adminUser := server.NewAdminUser("test_user")
events := generateEvents(accountID, adminUser.Id)
- handler := initEventsTestData(accountID, adminUser, events...)
+ handler := initEventsTestData(accountID, events...)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go
index b8247f78d..19c916dd2 100644
--- a/management/server/http/geolocation_handler_test.go
+++ b/management/server/http/geolocation_handler_test.go
@@ -7,12 +7,13 @@ import (
"net/http"
"net/http/httptest"
"path"
+ "path/filepath"
"testing"
"github.com/gorilla/mux"
+ "github.com/netbirdio/netbird/management/server"
"github.com/stretchr/testify/assert"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -24,32 +25,29 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler {
t.Helper()
var (
- mmdbPath = "../testdata/GeoLite2-City-Test.mmdb"
- geonamesDBPath = "../testdata/geonames-test.db"
+ mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb"
+ geonamesdbPath = "../testdata/geonames_20240305.db"
)
tempDir := t.TempDir()
- err := util.CopyFileContents(mmdbPath, path.Join(tempDir, geolocation.MMDBFileName))
+ err := util.CopyFileContents(mmdbPath, path.Join(tempDir, filepath.Base(mmdbPath)))
assert.NoError(t, err)
- err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile))
+ err = util.CopyFileContents(geonamesdbPath, path.Join(tempDir, filepath.Base(geonamesdbPath)))
assert.NoError(t, err)
- geo, err := geolocation.NewGeolocation(context.Background(), tempDir)
+ geo, err := geolocation.NewGeolocation(context.Background(), tempDir, false)
assert.NoError(t, err)
t.Cleanup(func() { _ = geo.Stop() })
return &GeolocationsHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
- return &server.Account{
- Id: claims.AccountId,
- Users: map[string]*server.User{
- "test_user": user,
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
+ return server.NewAdminUser(id), nil
},
},
geolocationManager: geo,
diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go
index af4d3116f..418228abf 100644
--- a/management/server/http/geolocations_handler.go
+++ b/management/server/http/geolocations_handler.go
@@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
func (l *GeolocationsHandler) authenticateUser(r *http.Request) error {
claims := l.claimsExtractor.FromRequestContext(r)
- _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims)
+ _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ return err
+ }
+
+ user, err := l.accountManager.GetUserByID(r.Context(), userID)
if err != nil {
return err
}
diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go
index c622d873a..f369d1a00 100644
--- a/management/server/http/groups_handler.go
+++ b/management/server/http/groups_handler.go
@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/gorilla/mux"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
@@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr
// GetAllGroups list for the account
func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id)
+ groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
groupsResponse := make([]*api.Group, 0, len(groups))
for _, group := range groups {
- groupsResponse = append(groupsResponse, toGroupResponse(account, group))
+ groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group))
}
util.WriteJSONObject(r.Context(), w, groupsResponse)
@@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) {
// UpdateGroup handles update to a group identified by a given ID
func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
return
}
- eg, ok := account.Groups[groupID]
- if !ok {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w)
- return
- }
-
- allGroup, err := account.GetGroupAll()
+ existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
+
+ allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
if allGroup.ID == groupID {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w)
return
@@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) {
ID: groupID,
Name: req.Name,
Peers: peers,
- Issued: eg.Issued,
- IntegrationReference: eg.IntegrationReference,
+ Issued: existingGroup.Issued,
+ IntegrationReference: existingGroup.IntegrationReference,
}
- if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil {
- log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err)
+ if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil {
+ log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
util.WriteError(r.Context(), err, w)
return
}
- util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
}
// CreateGroup handles group creation request
func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
Issued: nbgroup.GroupIssuedAPI,
}
- err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group)
+ err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group))
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group))
}
// DeleteGroup handles group deletion request
func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- aID := account.Id
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
@@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
return
}
- allGroup, err := account.GetGroupAll()
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- if allGroup.ID == groupID {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w)
- return
- }
-
- err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID)
+ err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
if err != nil {
_, ok := err.(*server.GroupLinkError)
if ok {
@@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) {
// GetGroup returns a group
func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+ groupID := mux.Vars(r)["groupId"]
+ if len(groupID) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
+ return
+ }
+
+ group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- switch r.Method {
- case http.MethodGet:
- groupID := mux.Vars(r)["groupId"]
- if len(groupID) == 0 {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
- return
- }
-
- group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id)
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group))
- default:
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w)
+ accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
+
+ util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
+
}
-func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
+func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group {
+ peersMap := make(map[string]*nbpeer.Peer, len(peers))
+ for _, peer := range peers {
+ peersMap[peer.ID] = peer
+ }
+
cache := make(map[string]api.PeerMinimum)
gr := api.Group{
Id: group.ID,
@@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group {
for _, pid := range group.Peers {
_, ok := cache[pid]
if !ok {
- peer, ok := account.Peers[pid]
+ peer, ok := peersMap[pid]
if !ok {
continue
}
diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go
index d5ed07c9e..7f3c81f18 100644
--- a/management/server/http/groups_handler_test.go
+++ b/management/server/http/groups_handler_test.go
@@ -14,6 +14,7 @@ import (
"github.com/gorilla/mux"
"github.com/magiconair/properties/assert"
+ "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
@@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
-func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
+func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler {
return &GroupsHandler{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error {
@@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler {
return nil
},
GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) {
- if groupID != "idofthegroup" {
+ groups := map[string]*nbgroup.Group{
+ "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
+ "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
+ "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
+ }
+
+ for _, group := range initGroups {
+ groups[group.ID] = group
+ }
+
+ group, ok := groups[groupID]
+ if !ok {
return nil, status.Errorf(status.NotFound, "not found")
}
- if groupID == "id-jwt-group" {
- return &nbgroup.Group{
- ID: "id-jwt-group",
- Name: "Default Group",
- Issued: nbgroup.GroupIssuedJWT,
- }, nil
- }
- return &nbgroup.Group{
- ID: "idofthegroup",
- Name: "Group",
- Issued: nbgroup.GroupIssuedAPI,
- }, nil
+
+ return group, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return &server.Account{
- Id: claims.AccountId,
- Domain: "hotmail.com",
- Peers: TestPeers,
- Users: map[string]*server.User{
- user.Id: user,
- },
- Groups: map[string]*nbgroup.Group{
- "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT},
- "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI},
- "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI},
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) {
+ if groupName == "All" {
+ return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil
+ }
+
+ return nil, fmt.Errorf("unknown group name")
+ },
+ GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
+ return maps.Values(TestPeers), nil
},
DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error {
if groupID == "linked-grp" {
@@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
- adminUser := server.NewAdminUser("test_user")
- p := initGroupTestData(adminUser, group)
+ p := initGroupTestData(group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) {
},
}
- adminUser := server.NewAdminUser("test_user")
- p := initGroupTestData(adminUser)
+ p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
- adminUser := server.NewAdminUser("test_user")
- p := initGroupTestData(adminUser)
+ p := initGroupTestData()
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
diff --git a/management/server/http/handler.go b/management/server/http/handler.go
index 3fe26d0ce..3f8a8554d 100644
--- a/management/server/http/handler.go
+++ b/management/server/http/handler.go
@@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
AuthCfg: authCfg,
}
- if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil {
+ if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err)
}
@@ -100,27 +100,6 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
api.addPostureCheckEndpoint()
api.addLocationsEndpoint()
- err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error {
- methods, err := route.GetMethods()
- if err != nil { // we may have wildcard routes from integrations without methods, skip them for now
- methods = []string{}
- }
- for _, method := range methods {
- template, err := route.GetPathTemplate()
- if err != nil {
- return err
- }
- err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method)
- if err != nil {
- return err
- }
- }
- return nil
- })
- if err != nil {
- return nil, err
- }
-
return rootRouter, nil
}
@@ -136,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() {
apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
+ apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
}
func (apiHandler *apiHandler) addUsersEndpoint() {
diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go
index c6e00bb2d..e7a2bc2ae 100644
--- a/management/server/http/nameservers_handler.go
+++ b/management/server/http/nameservers_handler.go
@@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg
// GetAllNameservers returns the list of nameserver groups for the account
func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
- nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id)
+ nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re
// CreateNameserverGroup handles nameserver group creation request
func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
return
}
- nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled)
+ nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt
// UpdateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
- err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup)
+ err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt
// DeleteNameserverGroup handles nameserver group deletion request
func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
return
}
- err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id)
+ err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt
// GetNameserverGroup handles a nameserver group Get request identified by ID
func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R
return
}
- nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID)
+ nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go
index 28b080571..98c2e402d 100644
--- a/management/server/http/nameservers_handler_test.go
+++ b/management/server/http/nameservers_handler_test.go
@@ -18,7 +18,6 @@ import (
"github.com/gorilla/mux"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
)
@@ -29,14 +28,6 @@ const (
testNSGroupAccountID = "test_id"
)
-var testingNSAccount = &server.Account{
- Id: testNSGroupAccountID,
- Domain: "hotmail.com",
- Users: map[string]*server.User{
- "test_user": server.NewAdminUser("test_user"),
- },
-}
-
var baseExistingNSGroup = &nbdns.NameServerGroup{
ID: existingNSGroupID,
Name: "super",
@@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler {
}
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
},
- GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testingNSAccount, testingAccount.Users["test_user"], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go
index 9d8448d3d..dfa9563e3 100644
--- a/management/server/http/pat_handler.go
+++ b/management/server/http/pat_handler.go
@@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH
// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
- userID := vars["userId"]
+ targetUserID := vars["userId"]
-	if len(userID) == 0 {
+	if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
- pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID)
+ pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) {
// GetToken is HTTP GET handler that returns a personal access token for the given user
func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
return
}
- pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
+ pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) {
// CreateToken is HTTP POST handler that creates a personal access token for the given user
func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
return
}
- pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn)
+ pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) {
// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID)
+ err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go
index b72f71468..c28228a50 100644
--- a/management/server/http/pat_handler_test.go
+++ b/management/server/http/pat_handler_test.go
@@ -77,8 +77,8 @@ func initPATTestData() *PATHandler {
}, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testAccount, testAccount.Users[existingUserID], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
if accountID != existingAccountID {
@@ -119,7 +119,7 @@ func initPATTestData() *PATHandler {
return jwtclaims.AuthorizationClaims{
UserId: existingUserID,
Domain: testDomain,
- AccountId: testNSGroupAccountID,
+ AccountId: existingAccountID,
}
}),
),
diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go
index 913d424d1..4fbbc3106 100644
--- a/management/server/http/peers_handler.go
+++ b/management/server/http/peers_handler.go
@@ -7,8 +7,6 @@ import (
"net/http"
"github.com/gorilla/mux"
- log "github.com/sirupsen/logrus"
-
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
@@ -16,6 +14,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
+ log "github.com/sirupsen/logrus"
)
// PeersHandler is a handler that returns peers of the account
@@ -71,15 +70,11 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
- customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
- netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
- accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
-
_, valid := validPeers[peer.ID]
- util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid))
+ util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid))
}
-func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) {
+func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) {
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -101,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
}
}
- peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update)
+ peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update)
if err != nil {
util.WriteError(ctx, err, w)
return
@@ -117,13 +112,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
return
}
- customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
- netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
- accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
-
_, valid := validPeers[peer.ID]
- util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid))
+ util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid))
}
func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
@@ -139,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string,
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -153,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodDelete:
- h.deletePeer(r.Context(), account.Id, user.Id, peerID, w)
+ h.deletePeer(r.Context(), accountID, userID, peerID, w)
return
- case http.MethodPut:
- h.updatePeer(r.Context(), account, user, peerID, w, r)
- return
- case http.MethodGet:
- h.getPeer(r.Context(), account, peerID, user.Id, w)
+ case http.MethodGet, http.MethodPut:
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ if r.Method == http.MethodGet {
+ h.getPeer(r.Context(), account, peerID, userID, w)
+ } else {
+ h.updatePeer(r.Context(), account, userID, peerID, w, r)
+ }
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
@@ -168,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) {
// GetAllPeers returns a list of all peers associated with a provided account
func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodGet {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
- return
- }
-
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id)
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -188,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
- respBody := make([]*api.PeerBatch, 0, len(peers))
- for _, peer := range peers {
+ respBody := make([]*api.PeerBatch, 0, len(account.Peers))
+ for _, peer := range account.Peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -220,32 +213,93 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv
}
}
+// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
+func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
+ claims := h.claimsExtractor.FromRequestContext(r)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ vars := mux.Vars(r)
+ peerID := vars["peerId"]
+ if len(peerID) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
+ return
+ }
+
+ account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ user, err := account.FindUser(userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ // If the user is regular user and does not own the peer
+ // with the given peerID return an empty list
+ if !user.HasAdminPower() && !user.IsServiceUser {
+ peer, ok := account.Peers[peerID]
+ if !ok {
+ util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w)
+ return
+ }
+
+ if peer.UserID != user.Id {
+ util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{})
+ return
+ }
+ }
+
+ dnsDomain := h.accountManager.GetDNSDomain()
+
+ validPeers, err := h.accountManager.GetValidatedPeers(account)
+ if err != nil {
+ log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
+ util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
+ return
+ }
+
+ customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain())
+ netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil)
+
+ util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
+}
+
func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer {
accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers))
for _, p := range netMap.Peers {
- ap := api.AccessiblePeer{
- Id: p.ID,
- Name: p.Name,
- Ip: p.IP.String(),
- DnsLabel: fqdn(p, dnsDomain),
- UserId: p.UserID,
- }
- accessiblePeers = append(accessiblePeers, ap)
+ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
for _, p := range netMap.OfflinePeers {
- ap := api.AccessiblePeer{
- Id: p.ID,
- Name: p.Name,
- Ip: p.IP.String(),
- DnsLabel: fqdn(p, dnsDomain),
- UserId: p.UserID,
- }
- accessiblePeers = append(accessiblePeers, ap)
+ accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain))
}
+
return accessiblePeers
}
+func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer {
+ return api.AccessiblePeer{
+ CityName: peer.Location.CityName,
+ Connected: peer.Status.Connected,
+ CountryCode: peer.Location.CountryCode,
+ DnsLabel: fqdn(peer, dnsDomain),
+ GeonameId: int(peer.Location.GeoNameID),
+ Id: peer.ID,
+ Ip: peer.IP.String(),
+ LastSeen: peer.Status.LastSeen,
+ Name: peer.Name,
+ Os: peer.Meta.OS,
+ UserId: peer.UserID,
+ }
+}
+
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
var groupsInfo []api.GroupMinimum
groupsChecked := make(map[string]struct{})
@@ -270,7 +324,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi
return groupsInfo
}
-func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer {
+func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer {
osVersion := peer.Meta.OSVersion
if osVersion == "" {
osVersion = peer.Meta.Core
@@ -296,7 +350,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
LoginExpirationEnabled: peer.LoginExpirationEnabled,
LastLogin: peer.LastLogin,
LoginExpired: peer.Status.LoginExpired,
- AccessiblePeers: accessiblePeer,
ApprovalRequired: !approved,
CountryCode: peer.Location.CountryCode,
CityName: peer.Location.CityName,
diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go
index 153c8f03a..f933eee14 100644
--- a/management/server/http/peers_handler_test.go
+++ b/management/server/http/peers_handler_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "fmt"
"io"
"net"
"net/http"
@@ -12,20 +13,29 @@ import (
"time"
"github.com/gorilla/mux"
-
- "github.com/netbirdio/netbird/management/server/http/api"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
-
- "github.com/netbirdio/netbird/management/server/jwtclaims"
-
- "github.com/magiconair/properties/assert"
-
"github.com/netbirdio/netbird/management/server"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
+ "github.com/netbirdio/netbird/management/server/http/api"
+ "github.com/netbirdio/netbird/management/server/jwtclaims"
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
+ "golang.org/x/exp/maps"
+
+ "github.com/stretchr/testify/assert"
+
"github.com/netbirdio/netbird/management/server/mock_server"
)
-const testPeerID = "test_peer"
-const noUpdateChannelTestPeerID = "no-update-channel"
+type ctxKey string
+
+const (
+ testPeerID = "test_peer"
+ noUpdateChannelTestPeerID = "no-update-channel"
+
+ adminUser = "admin_user"
+ regularUser = "regular_user"
+ serviceUser = "service_user"
+ userIDKey ctxKey = "user_id"
+)
func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
return &PeersHandler{
@@ -59,22 +69,61 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
- return &server.Account{
- Id: claims.AccountId,
- Domain: "hotmail.com",
- Peers: map[string]*nbpeer.Peer{
- peers[0].ID: peers[0],
- peers[1].ID: peers[1],
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
+ peersMap := make(map[string]*nbpeer.Peer)
+ for _, peer := range peers {
+ peersMap[peer.ID] = peer.Copy()
+ }
+
+ policy := &server.Policy{
+ ID: "policy",
+ AccountID: accountID,
+ Name: "policy",
+ Enabled: true,
+ Rules: []*server.PolicyRule{
+ {
+ ID: "rule",
+ Name: "rule",
+ Enabled: true,
+ Action: "accept",
+ Destinations: []string{"group1"},
+ Sources: []string{"group1"},
+ Bidirectional: true,
+ Protocol: "all",
+ Ports: []string{"80"},
+ },
},
+ }
+
+ srvUser := server.NewRegularUser(serviceUser)
+ srvUser.IsServiceUser = true
+
+ account := &server.Account{
+ Id: accountID,
+ Domain: "hotmail.com",
+ Peers: peersMap,
Users: map[string]*server.User{
- "test_user": user,
+ adminUser: server.NewAdminUser(adminUser),
+ regularUser: server.NewRegularUser(regularUser),
+ serviceUser: srvUser,
+ },
+ Groups: map[string]*nbgroup.Group{
+ "group1": {
+ ID: "group1",
+ AccountID: accountID,
+ Name: "group1",
+ Issued: "api",
+ Peers: maps.Keys(peersMap),
+ },
},
Settings: &server.Settings{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: time.Hour,
},
+ Policies: []*server.Policy{policy},
Network: &server.Network{
Identifier: "ciclqisab2ss43jdn8q0",
Net: net.IPNet{
@@ -83,7 +132,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
Serial: 51,
},
- }, user, nil
+ }
+
+ return account, nil
},
HasConnectedChannelFunc: func(peerID string) bool {
statuses := make(map[string]struct{})
@@ -99,8 +150,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
+ userID := r.Context().Value(userIDKey).(string)
return jwtclaims.AuthorizationClaims{
- UserId: "test_user",
+ UserId: userID,
Domain: "hotmail.com",
AccountId: "test_id",
}
@@ -197,6 +249,8 @@ func TestGetPeers(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
+ ctx := context.WithValue(context.Background(), userIDKey, "admin_user")
+ req = req.WithContext(ctx)
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
@@ -227,9 +281,15 @@ func TestGetPeers(t *testing.T) {
// hardcode this check for now as we only have two peers in this suite
assert.Equal(t, len(respBody), 2)
- assert.Equal(t, respBody[1].Connected, false)
- got = respBody[0]
+ for _, peer := range respBody {
+ if peer.Id == testPeerID {
+ got = peer
+ } else {
+ assert.Equal(t, peer.Connected, false)
+ }
+ }
+
} else {
got = &api.Peer{}
err = json.Unmarshal(content, got)
@@ -251,3 +311,119 @@ func TestGetPeers(t *testing.T) {
})
}
}
+
+func TestGetAccessiblePeers(t *testing.T) {
+ peer1 := &nbpeer.Peer{
+ ID: "peer1",
+ Key: "key1",
+ IP: net.ParseIP("100.64.0.1"),
+ Status: &nbpeer.PeerStatus{Connected: true},
+ Name: "peer1",
+ LoginExpirationEnabled: false,
+ UserID: regularUser,
+ }
+
+ peer2 := &nbpeer.Peer{
+ ID: "peer2",
+ Key: "key2",
+ IP: net.ParseIP("100.64.0.2"),
+ Status: &nbpeer.PeerStatus{Connected: true},
+ Name: "peer2",
+ LoginExpirationEnabled: false,
+ UserID: adminUser,
+ }
+
+ peer3 := &nbpeer.Peer{
+ ID: "peer3",
+ Key: "key3",
+ IP: net.ParseIP("100.64.0.3"),
+ Status: &nbpeer.PeerStatus{Connected: true},
+ Name: "peer3",
+ LoginExpirationEnabled: false,
+ UserID: regularUser,
+ }
+
+ p := initTestMetaData(peer1, peer2, peer3)
+
+ tt := []struct {
+ name string
+ peerID string
+ callerUserID string
+ expectedStatus int
+ expectedPeers []string
+ }{
+ {
+ name: "non admin user can access owned peer",
+ peerID: "peer1",
+ callerUserID: regularUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer2", "peer3"},
+ },
+ {
+ name: "non admin user can't access unowned peer",
+ peerID: "peer2",
+ callerUserID: regularUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{},
+ },
+ {
+ name: "admin user can access owned peer",
+ peerID: "peer2",
+ callerUserID: adminUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer1", "peer3"},
+ },
+ {
+ name: "admin user can access unowned peer",
+ peerID: "peer3",
+ callerUserID: adminUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer1", "peer2"},
+ },
+ {
+ name: "service user can access unowned peer",
+ peerID: "peer3",
+ callerUserID: serviceUser,
+ expectedStatus: http.StatusOK,
+ expectedPeers: []string{"peer1", "peer2"},
+ },
+ }
+
+ for _, tc := range tt {
+ t.Run(tc.name, func(t *testing.T) {
+
+ recorder := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil)
+ ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID)
+ req = req.WithContext(ctx)
+
+ router := mux.NewRouter()
+ router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
+ router.ServeHTTP(recorder, req)
+
+ res := recorder.Result()
+ if res.StatusCode != tc.expectedStatus {
+ t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus)
+ }
+
+ body, err := io.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("failed to read response body: %v", err)
+ }
+ defer res.Body.Close()
+
+ var accessiblePeers []api.AccessiblePeer
+ err = json.Unmarshal(body, &accessiblePeers)
+ if err != nil {
+ t.Fatalf("failed to unmarshal response: %v", err)
+ }
+
+ peerIDs := make([]string, len(accessiblePeers))
+ for i, peer := range accessiblePeers {
+ peerIDs[i] = peer.Id
+ }
+
+ assert.ElementsMatch(t, peerIDs, tc.expectedPeers)
+ })
+ }
+}
diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go
index 9622668f4..73f3803b5 100644
--- a/management/server/http/policies_handler.go
+++ b/management/server/http/policies_handler.go
@@ -6,6 +6,7 @@ import (
"strconv"
"github.com/gorilla/mux"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
@@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id)
+ listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- policies := []*api.Policy{}
- for _, policy := range accountPolicies {
- resp := toPolicyResponse(account, policy)
+ allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ policies := make([]*api.Policy, 0, len(listPolicies))
+ for _, policy := range listPolicies {
+ resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
@@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
- policyIdx := -1
- for i, policy := range account.Policies {
- if policy.ID == policyID {
- policyIdx = i
- break
- }
- }
- if policyIdx < 0 {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
- return
- }
-
- h.savePolicy(w, r, account, user, policyID)
-}
-
-// CreatePolicy handles policy creation request
-func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
- claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- h.savePolicy(w, r, account, user, "")
+ h.savePolicy(w, r, accountID, userID, policyID)
+}
+
+// CreatePolicy handles policy creation request
+func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
+ claims := h.claimsExtractor.FromRequestContext(r)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ h.savePolicy(w, r, accountID, userID, "")
}
// savePolicy handles policy creation and update
-func (h *Policies) savePolicy(
- w http.ResponseWriter,
- r *http.Request,
- account *server.Account,
- user *server.User,
- policyID string,
-) {
+func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) {
var req api.PutApiPoliciesPolicyIdJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -127,6 +122,8 @@ func (h *Policies) savePolicy(
return
}
+ isUpdate := policyID != ""
+
if policyID == "" {
policyID = xid.New().String()
}
@@ -141,8 +138,8 @@ func (h *Policies) savePolicy(
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
Name: rule.Name,
- Destinations: groupMinimumsToStrings(account, rule.Destinations),
- Sources: groupMinimumsToStrings(account, rule.Sources),
+ Destinations: rule.Destinations,
+ Sources: rule.Sources,
Bidirectional: rule.Bidirectional,
}
@@ -175,6 +172,11 @@ func (h *Policies) savePolicy(
return
}
+ if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w)
+ return
+ }
+
if rule.Ports != nil && len(*rule.Ports) != 0 {
for _, v := range *rule.Ports {
if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 {
@@ -185,10 +187,23 @@ func (h *Policies) savePolicy(
}
}
+ if rule.PortRanges != nil && len(*rule.PortRanges) != 0 {
+ for _, portRange := range *rule.PortRanges {
+ if portRange.Start < 1 || portRange.End > 65535 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w)
+ return
+ }
+ pr.PortRanges = append(pr.PortRanges, server.RulePortRange{
+ Start: uint16(portRange.Start),
+ End: uint16(portRange.End),
+ })
+ }
+ }
+
// validate policy object
switch pr.Protocol {
case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP:
- if len(pr.Ports) != 0 {
+ if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w)
return
}
@@ -197,7 +212,7 @@ func (h *Policies) savePolicy(
return
}
case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP:
- if !pr.Bidirectional && len(pr.Ports) == 0 {
+ if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w)
return
}
@@ -207,15 +222,21 @@ func (h *Policies) savePolicy(
}
if req.SourcePostureChecks != nil {
- policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks)
+ policy.SourcePostureChecks = *req.SourcePostureChecks
}
- if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil {
+ if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
util.WriteError(r.Context(), err, w)
return
}
- resp := toPolicyResponse(account, &policy)
+ allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ resp := toPolicyResponse(allGroups, &policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
@@ -227,12 +248,11 @@ func (h *Policies) savePolicy(
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- aID := account.Id
vars := mux.Vars(r)
policyID := vars["policyId"]
@@ -241,7 +261,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
return
}
- if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil {
+ if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -252,40 +272,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- switch r.Method {
- case http.MethodGet:
- vars := mux.Vars(r)
- policyID := vars["policyId"]
- if len(policyID) == 0 {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
- return
- }
-
- policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id)
- if err != nil {
- util.WriteError(r.Context(), err, w)
- return
- }
-
- resp := toPolicyResponse(account, policy)
- if len(resp.Rules) == 0 {
- util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
- return
- }
-
- util.WriteJSONObject(r.Context(), w, resp)
- default:
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w)
+ vars := mux.Vars(r)
+ policyID := vars["policyId"]
+ if len(policyID) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
+ return
}
+
+ policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ resp := toPolicyResponse(allGroups, policy)
+ if len(resp.Rules) == 0 {
+ util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
+ return
+ }
+
+ util.WriteJSONObject(r.Context(), w, resp)
}
-func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
+func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy {
+ groupsMap := make(map[string]*nbgroup.Group)
+ for _, group := range groups {
+ groupsMap[group.ID] = group
+ }
+
cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{
Id: &policy.ID,
@@ -306,16 +332,29 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
Protocol: api.PolicyRuleProtocol(r.Protocol),
Action: api.PolicyRuleAction(r.Action),
}
+
if len(r.Ports) != 0 {
portsCopy := r.Ports
rule.Ports = &portsCopy
}
+
+ if len(r.PortRanges) != 0 {
+ portRanges := make([]api.RulePortRange, 0, len(r.PortRanges))
+ for _, portRange := range r.PortRanges {
+ portRanges = append(portRanges, api.RulePortRange{
+ End: int(portRange.End),
+ Start: int(portRange.Start),
+ })
+ }
+ rule.PortRanges = &portRanges
+ }
+
for _, gid := range r.Sources {
_, ok := cache[gid]
if ok {
continue
}
- if group, ok := account.Groups[gid]; ok {
+ if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
@@ -325,13 +364,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
cache[gid] = minimum
}
}
+
for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid]
if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum)
continue
}
- if group, ok := account.Groups[gid]; ok {
+ if group, ok := groupsMap[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
@@ -345,28 +385,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic
}
return ap
}
-
-func groupMinimumsToStrings(account *server.Account, gm []string) []string {
- result := make([]string, 0, len(gm))
- for _, g := range gm {
- if _, ok := account.Groups[g]; !ok {
- continue
- }
- result = append(result, g)
- }
- return result
-}
-
-func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string {
- result := make([]string, 0, len(postureChecksIds))
- for _, id := range postureChecksIds {
- for _, postureCheck := range account.PostureChecks {
- if id == postureCheck.ID {
- result = append(result, id)
- continue
- }
- }
-
- }
- return result
-}
diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go
index 06274fb07..228ebcbce 100644
--- a/management/server/http/policies_handler_test.go
+++ b/management/server/http/policies_handler_test.go
@@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
- SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error {
+ SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
+ GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
+ return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
+ },
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
+ },
+ GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) {
+ user := server.NewAdminUser(userID)
return &server.Account{
- Id: claims.AccountId,
+ Id: accountID,
Domain: "hotmail.com",
Policies: []*server.Policy{
{ID: "id-existed"},
@@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
Users: map[string]*server.User{
"test_user": user,
},
- }, user, nil
+ }, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go
index 059cb3b80..1d020e9bc 100644
--- a/management/server/http/posture_checks_handler.go
+++ b/management/server/http/posture_checks_handler.go
@@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa
// GetAllPostureChecks list for the account
func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id)
+ listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- postureChecks := []*api.PostureCheck{}
- for _, postureCheck := range accountPostureChecks {
+ postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks))
+ for _, postureCheck := range listPostureChecks {
postureChecks = append(postureChecks, postureCheck.ToAPIResponse())
}
@@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt
// UpdatePostureCheck handles update to a posture check identified by a given ID
func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http
return
}
- postureChecksIdx := -1
- for i, postureCheck := range account.PostureChecks {
- if postureCheck.ID == postureChecksID {
- postureChecksIdx = i
- break
- }
- }
- if postureChecksIdx < 0 {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w)
+ _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
- p.savePostureChecks(w, r, account, user, postureChecksID)
+ p.savePostureChecks(w, r, accountID, userID, postureChecksID)
}
// CreatePostureCheck handles posture check creation request
func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- p.savePostureChecks(w, r, account, user, "")
+ p.savePostureChecks(w, r, accountID, userID, "")
}
// GetPostureCheck handles a posture check Get request identified by ID
func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
return
}
- postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id)
+ postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re
// DeletePostureCheck handles posture check deletion request
func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) {
claims := p.claimsExtractor.FromRequestContext(r)
- account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
return
}
- if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil {
+ if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http
}
// savePostureChecks handles posture checks create and update
-func (p *PostureChecksHandler) savePostureChecks(
- w http.ResponseWriter,
- r *http.Request,
- account *server.Account,
- user *server.User,
- postureChecksID string,
-) {
+func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) {
var (
err error
req api.PostureCheckUpdate
@@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks(
return
}
- if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil {
+ if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
util.WriteError(r.Context(), err, w)
return
}
diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go
index 974edafde..02f0f0d83 100644
--- a/management/server/http/posture_checks_handler_test.go
+++ b/management/server/http/posture_checks_handler_test.go
@@ -14,7 +14,6 @@ import (
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
- "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return accountPostureChecks, nil
},
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- user := server.NewAdminUser("test_user")
- return &server.Account{
- Id: claims.AccountId,
- Users: map[string]*server.User{
- "test_user": user,
- },
- PostureChecks: postureChecks,
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
},
geolocationManager: &geolocation.Geolocation{},
diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go
index 18c347334..ce4edee4f 100644
--- a/management/server/http/routes_handler.go
+++ b/management/server/http/routes_handler.go
@@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro
// GetAllRoutes returns the list of routes for the account
func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id)
+ routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) {
// CreateRoute handles route creation request
func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -117,15 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) {
peerGroupIds = *req.PeerGroups
}
- // Do not allow non-Linux peers
- if peer := account.GetPeer(peerId); peer != nil {
- if peer.Meta.GoOS != "linux" {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w)
- return
- }
+ var accessControlGroupIds []string
+ if req.AccessControlGroups != nil {
+ accessControlGroupIds = *req.AccessControlGroups
}
- newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute)
+ newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
+ req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute)
+
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -168,7 +167,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro
// UpdateRoute handles update to a route identified by a given ID
func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -181,7 +180,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
return
}
- _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
+ _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -204,14 +203,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
peerID = *req.Peer
}
- // do not allow non Linux peers
- if peer := account.GetPeer(peerID); peer != nil {
- if peer.Meta.GoOS != "linux" {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w)
- return
- }
- }
-
newRoute := &route.Route{
ID: route.ID(routeID),
NetID: route.NetID(req.NetworkId),
@@ -247,7 +238,11 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.PeerGroups = *req.PeerGroups
}
- err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute)
+ if req.AccessControlGroups != nil {
+ newRoute.AccessControlGroups = *req.AccessControlGroups
+ }
+
+ err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -265,7 +260,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) {
// DeleteRoute handles route deletion request
func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -277,7 +272,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
+ err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -289,7 +284,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) {
// GetRoute handles a route Get request identified by ID
func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -301,7 +296,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) {
return
}
- foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id)
+ foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w)
return
@@ -340,6 +335,9 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
if len(serverRoute.PeerGroups) > 0 {
route.PeerGroups = &serverRoute.PeerGroups
}
+ if len(serverRoute.AccessControlGroups) > 0 {
+ route.AccessControlGroups = &serverRoute.AccessControlGroups
+ }
return route, nil
}
diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go
index 40075eb9d..83bd7004d 100644
--- a/management/server/http/routes_handler_test.go
+++ b/management/server/http/routes_handler_test.go
@@ -105,32 +105,44 @@ func initRoutesTestData() *RoutesHandler {
}
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
},
- CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
+ CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
if peerID == notFoundPeerID {
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID {
return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0])
}
+ if peerID != "" {
+ if peerID == nonLinuxExistingPeerID {
+ return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+ }
+
return &route.Route{
- ID: existingRouteID,
- NetID: netID,
- Peer: peerID,
- PeerGroups: peerGroups,
- Network: prefix,
- Domains: domains,
- NetworkType: networkType,
- Description: description,
- Masquerade: masquerade,
- Enabled: enabled,
- Groups: groups,
- KeepRoute: keepRoute,
+ ID: existingRouteID,
+ NetID: netID,
+ Peer: peerID,
+ PeerGroups: peerGroups,
+ Network: prefix,
+ Domains: domains,
+ NetworkType: networkType,
+ Description: description,
+ Masquerade: masquerade,
+ Enabled: enabled,
+ Groups: groups,
+ KeepRoute: keepRoute,
+ AccessControlGroups: accessControlGroups,
}, nil
},
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
if r.Peer == notFoundPeerID {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer)
}
+
+ if r.Peer == nonLinuxExistingPeerID {
+ return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+
return nil
},
DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error {
@@ -139,8 +151,9 @@ func initRoutesTestData() *RoutesHandler {
}
return nil
},
- GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return testingAccount, testingAccount.Users["test_user"], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) {
+ //return testingAccount, testingAccount.Users["test_user"], nil
+ return testingAccount.Id, testingAccount.Users["test_user"].Id, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
@@ -256,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) {
Groups: []string{existingGroupID},
},
},
+ {
+ name: "POST OK With Access Control Groups",
+ requestType: http.MethodPost,
+ requestPath: "/api/routes",
+ requestBody: bytes.NewBuffer(
+ []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
+ expectedStatus: http.StatusOK,
+ expectedBody: true,
+ expectedRoute: &api.Route{
+ Id: existingRouteID,
+ Description: "Post",
+ NetworkId: "awesomeNet",
+ Network: toPtr("192.168.0.0/16"),
+ Peer: &existingPeerID,
+ NetworkType: route.IPv4NetworkString,
+ Masquerade: false,
+ Enabled: false,
+ Groups: []string{existingGroupID},
+ AccessControlGroups: &[]string{existingGroupID},
+ },
+ },
{
name: "POST Non Linux Peer",
requestType: http.MethodPost,
diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go
index 8ee7dfaba..8514f0b55 100644
--- a/management/server/http/setupkeys_handler.go
+++ b/management/server/http/setupkeys_handler.go
@@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg)
// CreateSetupKey is a POST requests that creates a new SetupKey
func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
if req.Ephemeral != nil {
ephemeral = *req.Ephemeral
}
- setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn,
- req.AutoGroups, req.UsageLimit, user.Id, ephemeral)
+ setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn,
+ req.AutoGroups, req.UsageLimit, userID, ephemeral)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request
// GetSetupKey is a GET request to get a SetupKey by ID
func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
- key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID)
+ key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) {
// UpdateSetupKey is a PUT request to update server.SetupKey
func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey.Name = req.Name
newKey.Id = keyID
- newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id)
+ newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
// GetAllSetupKeys is a GET request that returns a list of SetupKey
func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id)
+ setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go
index bfa0ec008..2d15287af 100644
--- a/management/server/http/setupkeys_handler_test.go
+++ b/management/server/http/setupkeys_handler_test.go
@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
- nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
) *SetupKeysHandler {
return &SetupKeysHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return &server.Account{
- Id: testAccountID,
- Domain: "hotmail.com",
- Users: map[string]*server.User{
- user.Id: user,
- },
- SetupKeys: map[string]*server.SetupKey{
- defaultKey.Key: defaultKey,
- },
- Groups: map[string]*nbgroup.Group{
- "group-1": {ID: "group-1", Peers: []string{"A", "B"}},
- "id-all": {ID: "id-all", Name: "All"},
- },
- }, user, nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return claims.AccountId, claims.UserId, nil
},
CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string,
_ int, _ string, ephemeral bool,
diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go
index 2c2aed842..6e151a0da 100644
--- a/management/server/http/users_handler.go
+++ b/management/server/http/users_handler.go
@@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
- userID := vars["userId"]
- if len(userID) == 0 {
+ targetUserID := vars["userId"]
+ if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
- existingUser, ok := account.Users[userID]
- if !ok {
- util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w)
+ existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
return
}
@@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
return
}
- newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{
- Id: userID,
+ newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{
+ Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
Blocked: req.IsBlocked,
@@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID)
+ err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
- newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{
+ newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id)
+ data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
}
claims := h.claimsExtractor.FromRequestContext(r)
- account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims)
+ accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) {
return
}
- err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID)
+ err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go
index a78ac3a4e..f3d989da1 100644
--- a/management/server/http/users_handler_test.go
+++ b/management/server/http/users_handler_test.go
@@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{
func initUsersTestData() *UsersHandler {
return &UsersHandler{
accountManager: &mock_server.MockAccountManager{
- GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
- return usersTestAccount, usersTestAccount.Users[claims.UserId], nil
+ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ return usersTestAccount.Id, claims.UserId, nil
+ },
+ GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) {
+ return usersTestAccount.Users[id], nil
},
GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go
index 729b49733..9d7626844 100644
--- a/management/server/idp/zitadel.go
+++ b/management/server/idp/zitadel.go
@@ -2,10 +2,12 @@ package idp
import (
"context"
+ "errors"
"fmt"
"io"
"net/http"
"net/url"
+ "slices"
"strings"
"sync"
"time"
@@ -97,6 +99,42 @@ type zitadelUserResponse struct {
PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"`
}
+// readZitadelError parses errors returned by the zitadel APIs from a response.
+func readZitadelError(body io.ReadCloser) error {
+ bodyBytes, err := io.ReadAll(body)
+ if err != nil {
+ return fmt.Errorf("failed to read response body: %w", err)
+ }
+
+ helper := JsonParser{}
+ var target map[string]interface{}
+ err = helper.Unmarshal(bodyBytes, &target)
+ if err != nil {
+ return fmt.Errorf("error unparsable body: %s", string(bodyBytes))
+ }
+
+ // ensure keys are ordered for consistent logging behaviour.
+ errorKeys := make([]string, 0, len(target))
+ for k := range target {
+ errorKeys = append(errorKeys, k)
+ }
+ slices.Sort(errorKeys)
+
+ var errsOut []string
+ for _, k := range errorKeys {
+ if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded {
+ continue
+ }
+ errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k]))
+ }
+
+ if len(errsOut) == 0 {
+ return errors.New("unknown error")
+ }
+
+ return errors.New(strings.Join(errsOut, " "))
+}
+
// NewZitadelManager creates a new instance of the ZitadelManager.
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon
}
if resp.StatusCode != http.StatusOK {
- return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode)
+ zErr := readZitadelError(resp.Body)
+ return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr)
}
return resp, nil
@@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
- return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode)
+ zErr := readZitadelError(resp.Body)
+
+ return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
}
return io.ReadAll(resp.Body)
@@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
- return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
+ zErr := readZitadelError(resp.Body)
+
+ return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr)
}
return io.ReadAll(resp.Body)
diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go
index 6bc612e78..722f94fe0 100644
--- a/management/server/idp/zitadel_test.go
+++ b/management/server/idp/zitadel_test.go
@@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) {
}
func TestZitadelRequestJWTToken(t *testing.T) {
-
type requestJWTTokenTest struct {
name string
inputCode int
@@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) {
requestJWTTokenTestCase2 := requestJWTTokenTest{
name: "Request Bad Status Code",
inputCode: 400,
- inputRespBody: "{}",
+ inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}",
helper: JsonParser{},
- expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
+ expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"),
expectedToken: "",
}
for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} {
t.Run(testCase.name, func(t *testing.T) {
-
jwtReqClient := mockHTTPClient{
resBody: testCase.inputRespBody,
code: testCase.inputCode,
@@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) {
}
parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{
name: "Parse Bad json JWT Body",
- inputRespBody: "",
+ inputRespBody: "{}",
helper: JsonParser{},
expectedToken: "",
expectedExpiresIn: 0,
@@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) {
inputCode: 400,
inputResBody: "{}",
helper: JsonParser{},
- expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"),
+ expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"),
expectedCode: 200,
expectedToken: "",
}
diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go
index 39676982e..d5c1e7c9e 100644
--- a/management/server/jwtclaims/jwtValidator.go
+++ b/management/server/jwtclaims/jwtValidator.go
@@ -1,14 +1,12 @@
package jwtclaims
import (
- "bytes"
"context"
+ "crypto/ecdsa"
+ "crypto/elliptic"
"crypto/rsa"
- "crypto/x509"
"encoding/base64"
- "encoding/binary"
"encoding/json"
- "encoding/pem"
"errors"
"fmt"
"math/big"
@@ -41,11 +39,6 @@ type Options struct {
// When set, all requests with the OPTIONS method will use authentication
// Default: false
EnableAuthOnOptions bool
- // When set, the middelware verifies that tokens are signed with the specific signing algorithm
- // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks
- // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/
- // Default: nil
- SigningMethod jwt.SigningMethod
}
// Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation
@@ -54,6 +47,18 @@ type Jwks struct {
expiresInTime time.Time
}
+// The supported elliptic curve types
+const (
+	// p256 represents the P-256 cryptographic elliptic curve type.
+	p256 = "P-256"
+
+	// p384 represents the P-384 cryptographic elliptic curve type.
+	p384 = "P-384"
+
+	// p521 represents the P-521 cryptographic elliptic curve type.
+	p521 = "P-521"
+)
+
// JSONWebKey is a representation of a Jason Web Key
type JSONWebKey struct {
Kty string `json:"kty"`
@@ -61,6 +66,9 @@ type JSONWebKey struct {
Use string `json:"use"`
N string `json:"n"`
E string `json:"e"`
+ Crv string `json:"crv"`
+ X string `json:"x"`
+ Y string `json:"y"`
X5c []string `json:"x5c"`
}
@@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
}
}
- cert, err := getPemCert(ctx, token, keys)
+ publicKey, err := getPublicKey(ctx, token, keys)
if err != nil {
+ log.WithContext(ctx).Errorf("getPublicKey error: %s", err)
return nil, err
}
- result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
- return result, nil
+ return publicKey, nil
},
- SigningMethod: jwt.SigningMethodRS256,
EnableAuthOnOptions: false,
}
@@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
// Check if there was an error in parsing...
if err != nil {
log.WithContext(ctx).Errorf("error parsing token: %v", err)
- return nil, fmt.Errorf("Error parsing token: %w", err)
- }
-
- if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] {
- errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s",
- m.options.SigningMethod.Alg(),
- parsedToken.Header["alg"])
- log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg)
- return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg)
+ return nil, fmt.Errorf("error parsing token: %w", err)
}
// Check if the parsed token is valid...
@@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) {
return jwks, err
}
-func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) {
+func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) {
// todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time
- cert := ""
for k := range jwks.Keys {
if token.Header["kid"] != jwks.Keys[k].Kid {
@@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro
}
if len(jwks.Keys[k].X5c) != 0 {
- cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
- return cert, nil
+ cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----"
+ return jwt.ParseRSAPublicKeyFromPEM([]byte(cert))
}
- log.WithContext(ctx).Debugf("generating validation pem from JWK")
- return generatePemFromJWK(jwks.Keys[k])
+
+ if jwks.Keys[k].Kty == "RSA" {
+ log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK")
+ return getPublicKeyFromRSA(jwks.Keys[k])
+ }
+ if jwks.Keys[k].Kty == "EC" {
+ log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK")
+ return getPublicKeyFromECDSA(jwks.Keys[k])
+ }
+
+ log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty)
}
- return cert, errors.New("unable to find appropriate key")
+ return nil, errors.New("unable to find appropriate key")
}
-func generatePemFromJWK(jwk JSONWebKey) (string, error) {
- decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N)
- if err != nil {
- return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err)
+func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) {
+
+ if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" {
+ return nil, fmt.Errorf("ecdsa key incomplete")
}
- intModules := big.NewInt(0)
- intModules.SetBytes(decodedModulus)
-
- exponent, err := convertExponentStringToInt(jwk.E)
- if err != nil {
- return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err)
+ var xCoordinate []byte
+ if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil {
+ return nil, err
}
- publicKey := &rsa.PublicKey{
- N: intModules,
- E: exponent,
+ var yCoordinate []byte
+ if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil {
+ return nil, err
}
- derKey, err := x509.MarshalPKIXPublicKey(publicKey)
- if err != nil {
- return "", fmt.Errorf("unable to convert public key to DER, error: %s", err)
+ publicKey = &ecdsa.PublicKey{}
+
+ var curve elliptic.Curve
+ switch jwk.Crv {
+ case p256:
+ curve = elliptic.P256()
+ case p384:
+ curve = elliptic.P384()
+ case p521:
+ curve = elliptic.P521()
}
- block := &pem.Block{
- Type: "RSA PUBLIC KEY",
- Bytes: derKey,
- }
+ publicKey.Curve = curve
+ publicKey.X = big.NewInt(0).SetBytes(xCoordinate)
+ publicKey.Y = big.NewInt(0).SetBytes(yCoordinate)
- var out bytes.Buffer
- err = pem.Encode(&out, block)
- if err != nil {
- return "", fmt.Errorf("unable to encode Pem block , error: %s", err)
- }
-
- return out.String(), nil
+ return publicKey, nil
}
-func convertExponentStringToInt(stringExponent string) (int, error) {
- decodedString, err := base64.StdEncoding.DecodeString(stringExponent)
+func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) {
+
+ decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E)
if err != nil {
- return 0, err
+ return nil, err
}
- exponentBytes := decodedString
- if len(decodedString) < 8 {
- exponentBytes = make([]byte, 8-len(decodedString), 8)
- exponentBytes = append(exponentBytes, decodedString...)
+ decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N)
+ if err != nil {
+ return nil, err
}
- bytesReader := bytes.NewReader(exponentBytes)
- var exponent uint64
- err = binary.Read(bytesReader, binary.BigEndian, &exponent)
- if err != nil {
- return 0, err
- }
+ var n, e big.Int
+ e.SetBytes(decodedE)
+ n.SetBytes(decodedN)
- return int(exponent), nil
+ return &rsa.PublicKey{
+ E: int(e.Int64()),
+ N: &n,
+ }, nil
}
// getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header
@@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0
}
+
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go
index fe1e36d47..ff09129bd 100644
--- a/management/server/management_proto_test.go
+++ b/management/server/management_proto_test.go
@@ -3,13 +3,17 @@ package server
import (
"context"
"fmt"
+ "io"
"net"
"os"
"path/filepath"
"runtime"
+ "sync"
+ "sync/atomic"
"testing"
"time"
+ log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
@@ -24,6 +28,12 @@ import (
"github.com/netbirdio/netbird/util"
)
+type TestingT interface {
+ require.TestingT
+ Helper()
+ Cleanup(func())
+}
+
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
@@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) {
defer func() {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
- mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{
+ mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
}
}
-func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
+func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) {
t.Helper()
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
@@ -429,10 +439,11 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
if err != nil {
return nil, nil, "", err
}
- turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
+
+ secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
ephemeralMgr := NewEphemeralManager(store, accountManager)
- mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr)
+ mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr)
if err != nil {
return nil, nil, "", err
}
@@ -485,7 +496,7 @@ func testSyncStatusRace(t *testing.T) {
os.Remove(filepath.Join(dir, "store.json")) //nolint
}()
- mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{
+ mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{
Stuns: []*Host{{
Proto: "udp",
URI: "stun:stun.wiretrustee.com:3468",
@@ -545,7 +556,6 @@ func testSyncStatusRace(t *testing.T) {
ctx2, cancelFunc2 := context.WithCancel(context.Background())
- //client.
sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{
WgPubKey: concurrentPeerKey2.PublicKey().String(),
Body: message2,
@@ -574,7 +584,7 @@ func testSyncStatusRace(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
- //client.
+ // client.
sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{
WgPubKey: peerWithInvalidStatus.PublicKey().String(),
Body: message,
@@ -617,7 +627,7 @@ func testSyncStatusRace(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
- peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String())
+ peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String())
if err != nil {
t.Fatal(err)
return
@@ -626,3 +636,221 @@ func testSyncStatusRace(t *testing.T) {
t.Fatal("Peer should be connected")
}
}
+
+func Test_LoginPerformance(t *testing.T) {
+ if os.Getenv("CI") == "true" || runtime.GOOS == "windows" {
+ t.Skip("Skipping test on CI or Windows")
+ }
+
+ t.Setenv("NETBIRD_STORE_ENGINE", "sqlite")
+
+ benchCases := []struct {
+ name string
+ peers int
+ accounts int
+ }{
+ // {"XXS", 5, 1},
+ // {"XS", 10, 1},
+ // {"S", 100, 1},
+ // {"M", 250, 1},
+ // {"L", 500, 1},
+ // {"XL", 750, 1},
+ {"XXL", 5000, 1},
+ }
+
+ log.SetOutput(io.Discard)
+ defer log.SetOutput(os.Stderr)
+
+ for _, bc := range benchCases {
+ t.Run(bc.name, func(t *testing.T) {
+ t.Helper()
+ dir := t.TempDir()
+ err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ os.Remove(filepath.Join(dir, "store.json")) //nolint
+ }()
+
+ mgmtServer, am, _, err := startManagementForTest(t, &Config{
+ Stuns: []*Host{{
+ Proto: "udp",
+ URI: "stun:stun.wiretrustee.com:3468",
+ }},
+ TURNConfig: &TURNConfig{
+ TimeBasedCredentials: false,
+ CredentialsTTL: util.Duration{},
+ Secret: "whatever",
+ Turns: []*Host{{
+ Proto: "udp",
+ URI: "turn:stun.wiretrustee.com:3468",
+ }},
+ },
+ Signal: &Host{
+ Proto: "http",
+ URI: "signal.wiretrustee.com:10000",
+ },
+ Datadir: dir,
+ HttpConfig: nil,
+ })
+ if err != nil {
+ t.Fatal(err)
+ return
+ }
+ defer mgmtServer.GracefulStop()
+
+ t.Logf("management setup complete, start registering peers")
+
+ var counter int32
+ var counterStart int32
+ var wgAccount sync.WaitGroup
+ var mu sync.Mutex
+ messageCalls := []func() error{}
+ for j := 0; j < bc.accounts; j++ {
+ wgAccount.Add(1)
+ var wgPeer sync.WaitGroup
+ go func(j int, counter *int32, counterStart *int32) {
+ defer wgAccount.Done()
+
+ account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j))
+ if err != nil {
+ t.Logf("account creation failed: %v", err)
+ return
+ }
+
+ setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false)
+ if err != nil {
+ t.Logf("error creating setup key: %v", err)
+ return
+ }
+
+ startTime := time.Now()
+ for i := 0; i < bc.peers; i++ {
+ wgPeer.Add(1)
+ key, err := wgtypes.GeneratePrivateKey()
+ if err != nil {
+ t.Logf("failed to generate key: %v", err)
+ return
+ }
+
+ meta := &mgmtProto.PeerSystemMeta{
+ Hostname: key.PublicKey().String(),
+ GoOS: runtime.GOOS,
+ OS: runtime.GOOS,
+ Core: "core",
+ Platform: "platform",
+ Kernel: "kernel",
+ WiretrusteeVersion: "",
+ }
+
+ peerLogin := PeerLogin{
+ WireGuardPubKey: key.String(),
+ SSHKey: "random",
+ Meta: extractPeerMeta(context.Background(), meta),
+ SetupKey: setupKey.Key,
+ ConnectionIP: net.IP{1, 1, 1, 1},
+ }
+
+ login := func() error {
+ _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+ if err != nil {
+ t.Logf("failed to login peer: %v", err)
+ return err
+ }
+ atomic.AddInt32(counter, 1)
+ if *counter%100 == 0 {
+ t.Logf("finished %d login calls", *counter)
+ }
+ return nil
+ }
+
+ mu.Lock()
+ messageCalls = append(messageCalls, login)
+ mu.Unlock()
+
+ go func(peerLogin PeerLogin, counterStart *int32) {
+ defer wgPeer.Done()
+ _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
+ if err != nil {
+ t.Logf("failed to login peer: %v", err)
+ return
+ }
+
+ atomic.AddInt32(counterStart, 1)
+ if *counterStart%100 == 0 {
+ t.Logf("registered %d peers", *counterStart)
+ }
+ }(peerLogin, counterStart)
+
+ }
+ wgPeer.Wait()
+
+ t.Logf("Time for registration: %s", time.Since(startTime))
+ }(j, &counter, &counterStart)
+ }
+
+ wgAccount.Wait()
+
+ t.Logf("prepared %d login calls", len(messageCalls))
+ testLoginPerformance(t, messageCalls)
+
+ })
+ }
+}
+
+func testLoginPerformance(t *testing.T, loginCalls []func() error) {
+ t.Helper()
+ wgSetup := sync.WaitGroup{}
+ startChan := make(chan struct{})
+
+ wgDone := sync.WaitGroup{}
+ durations := []time.Duration{}
+ l := sync.Mutex{}
+
+ for i, function := range loginCalls {
+ wgSetup.Add(1)
+ wgDone.Add(1)
+ go func(function func() error, i int) {
+ defer wgDone.Done()
+ wgSetup.Done()
+
+ <-startChan
+ start := time.Now()
+
+ err := function()
+ if err != nil {
+ t.Logf("Error: %v", err)
+ return
+ }
+
+ duration := time.Since(start)
+ l.Lock()
+ durations = append(durations, duration)
+ l.Unlock()
+ }(function, i)
+ }
+
+ wgSetup.Wait()
+ t.Logf("starting login calls")
+ close(startChan)
+ wgDone.Wait()
+ var tMin, tMax, tSum time.Duration
+ for i, d := range durations {
+ if i == 0 {
+ tMin = d
+ tMax = d
+ tSum = d
+ continue
+ }
+ if d < tMin {
+ tMin = d
+ }
+ if d > tMax {
+ tMax = d
+ }
+ tSum += d
+ }
+ tAvg := tSum / time.Duration(len(durations))
+ t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg)
+}
diff --git a/management/server/management_test.go b/management/server/management_test.go
index 62e7f5a05..3956d96b1 100644
--- a/management/server/management_test.go
+++ b/management/server/management_test.go
@@ -552,8 +552,9 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}
- turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig)
- mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil)
+
+ secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
+ mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
Expect(err).NotTo(HaveOccurred())
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index f4452eaee..8dc8f6a4f 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -23,10 +23,11 @@ import (
type MockAccountManager struct {
GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error)
+ GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error)
CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType,
expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error)
GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error)
- GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error)
+ GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error)
GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -48,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
- SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error
+ SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@@ -56,7 +57,7 @@ type MockAccountManager struct {
MarkPATUsedFunc func(ctx context.Context, pat string) error
UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error
UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
- CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
+	CreateRouteFunc                     func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error)
GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error
DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error
@@ -78,7 +79,7 @@ type MockAccountManager struct {
DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error
ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error)
CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
- GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error)
+ GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
DeleteAccountFunc func(ctx context.Context, accountID, userID string) error
GetDNSDomainFunc func() string
@@ -104,6 +105,9 @@ type MockAccountManager struct {
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
+ GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error)
+ GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error)
+ GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error)
}
func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
@@ -189,16 +193,14 @@ func (am *MockAccountManager) CreateSetupKey(
return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented")
}
-// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface
-func (am *MockAccountManager) GetAccountByUserOrAccountID(
- ctx context.Context, userId, accountId, domain string,
-) (*server.Account, error) {
- if am.GetAccountByUserOrAccountIdFunc != nil {
- return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain)
+// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface
+func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) {
+ if am.GetAccountIDByUserOrAccountIdFunc != nil {
+ return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain)
}
- return nil, status.Errorf(
+ return "", status.Errorf(
codes.Unimplemented,
- "method GetAccountByUserOrAccountID is not implemented",
+ "method GetAccountIDByUserOrAccountID is not implemented",
)
}
@@ -364,7 +366,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID,
if am.DeleteRuleFunc != nil {
return am.DeleteRuleFunc(ctx, accountID, ruleID, userID)
}
- return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
+ return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented")
}
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
@@ -376,9 +378,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
-func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error {
+func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
if am.SavePolicyFunc != nil {
- return am.SavePolicyFunc(ctx, accountID, userID, policy)
+ return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
}
return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
@@ -431,9 +433,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID
}
// CreateRoute mock implementation of CreateRoute from server.AccountManager interface
-func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
+func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
if am.CreateRouteFunc != nil {
- return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute)
+		return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, accessControlGroupID, enabled, userID, keepRoute)
}
return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented")
}
@@ -592,14 +594,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID
return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented")
}
-// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface
-func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User,
- error,
-) {
- if am.GetAccountFromTokenFunc != nil {
- return am.GetAccountFromTokenFunc(ctx, claims)
+// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface
+func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) {
+ if am.GetAccountIDFromTokenFunc != nil {
+ return am.GetAccountIDFromTokenFunc(ctx, claims)
}
- return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented")
+ return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented")
}
func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error {
@@ -793,3 +793,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe
}
return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented")
}
+
+// GetAccountByID mocks GetAccountByID of the AccountManager interface
+func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) {
+ if am.GetAccountByIDFunc != nil {
+ return am.GetAccountByIDFunc(ctx, accountID, userID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented")
+}
+
+// GetUserByID mocks GetUserByID of the AccountManager interface
+func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) {
+ if am.GetUserByIDFunc != nil {
+ return am.GetUserByIDFunc(ctx, id)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented")
+}
+
+func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) {
+ if am.GetAccountSettingsFunc != nil {
+ return am.GetAccountSettingsFunc(ctx, accountID, userID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented")
+}
+
+func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) {
+ if am.GetAccountFunc != nil {
+ return am.GetAccountFunc(ctx, accountID)
+ }
+ return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
+}
diff --git a/management/server/nameserver.go b/management/server/nameserver.go
index 2cd934065..751ca12bc 100644
--- a/management/server/nameserver.go
+++ b/management/server/nameserver.go
@@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) {
-
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups")
- }
-
- nsGroup, found := account.NameServerGroups[nsGroupID]
- if found {
- return nsGroup.Copy(), nil
- }
-
- return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID)
+ return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
// ListNameServerGroups returns a list of nameserver groups from account
func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) {
-
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
- nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups))
- for _, item := range account.NameServerGroups {
- nsGroups = append(nsGroups, item.Copy())
- }
-
- return nsGroups, nil
+ return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
}
func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error {
diff --git a/management/server/network.go b/management/server/network.go
index 91d844c3e..8fb6a8b3c 100644
--- a/management/server/network.go
+++ b/management/server/network.go
@@ -26,12 +26,13 @@ const (
)
type NetworkMap struct {
- Peers []*nbpeer.Peer
- Network *Network
- Routes []*route.Route
- DNSConfig nbdns.Config
- OfflinePeers []*nbpeer.Peer
- FirewallRules []*FirewallRule
+ Peers []*nbpeer.Peer
+ Network *Network
+ Routes []*route.Route
+ DNSConfig nbdns.Config
+ OfflinePeers []*nbpeer.Peer
+ FirewallRules []*FirewallRule
+ RoutesFirewallRules []*RouteFirewallRule
}
type Network struct {
diff --git a/management/server/peer.go b/management/server/peer.go
index 7e8dbe1a6..97e11c08a 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
- "slices"
"strings"
"sync"
"time"
@@ -12,6 +11,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
+ "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/proto"
@@ -185,9 +185,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
- peerLabelUpdated := peer.Name != update.Name
-
- if peerLabelUpdated {
+ if peer.Name != update.Name {
peer.Name = update.Name
existingLabels := account.getPeerDNSLabels()
@@ -222,16 +220,13 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
account.UpdatePeer(peer)
+
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
- expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration)
-
- if peerLabelUpdated || (expired && peer.LoginExpirationEnabled) {
- am.updateAccountPeers(ctx, account)
- }
+ am.updateAccountPeers(ctx, account)
return peer, nil
}
@@ -295,8 +290,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
- updateAccountPeers := isPeerInActiveGroup(account, peerID)
-
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
return err
@@ -307,9 +300,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
- if updateAccountPeers {
- am.updateAccountPeers(ctx, account)
- }
+ am.updateAccountPeers(ctx, account)
return nil
}
@@ -383,179 +374,207 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}()
- var account *Account
- // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
- account, err = am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return nil, nil, nil, err
- }
-
- if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" {
- if am.idpManager != nil {
- userdata, err := am.lookupUserInCache(ctx, userID, account)
- if err == nil && userdata != nil {
- peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
- }
- }
- }
-
// This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice.
// Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow)
// and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry.
- _, err = account.FindPeerByPubKey(peer.Key)
+ _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
Timestamp: time.Now().UTC(),
- AccountID: account.Id,
+ AccountID: accountID,
}
- var ephemeral bool
- setupKeyName := ""
- if !addedByUser {
- // validate the setup key if adding with a key
- sk, err := account.FindSetupKey(upperKey)
- if err != nil {
- return nil, nil, nil, err
- }
+ var newPeer *nbpeer.Peer
- if !sk.IsValid() {
- return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
- }
-
- account.SetupKeys[sk.Key] = sk.IncrementUsage()
- opEvent.InitiatorID = sk.Id
- opEvent.Activity = activity.PeerAddedWithSetupKey
- ephemeral = sk.Ephemeral
- setupKeyName = sk.Name
- } else {
- opEvent.InitiatorID = userID
- opEvent.Activity = activity.PeerAddedByUser
- }
-
- takenIps := account.getTakenIPs()
- existingLabels := account.getPeerDNSLabels()
-
- newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels)
- if err != nil {
- return nil, nil, nil, err
- }
-
- peer.DNSLabel = newLabel
- network := account.Network
- nextIp, err := AllocatePeerIP(network.Net, takenIps)
- if err != nil {
- return nil, nil, nil, err
- }
-
- registrationTime := time.Now().UTC()
-
- newPeer := &nbpeer.Peer{
- ID: xid.New().String(),
- Key: peer.Key,
- SetupKey: upperKey,
- IP: nextIp,
- Meta: peer.Meta,
- Name: peer.Meta.Hostname,
- DNSLabel: newLabel,
- UserID: userID,
- Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
- SSHEnabled: false,
- SSHKey: peer.SSHKey,
- LastLogin: registrationTime,
- CreatedAt: registrationTime,
- LoginExpirationEnabled: addedByUser,
- Ephemeral: ephemeral,
- Location: peer.Location,
- }
-
- if am.geo != nil && newPeer.Location.ConnectionIP != nil {
- location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
- if err != nil {
- log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ var groupsToAdd []string
+ var setupKeyID string
+ var setupKeyName string
+ var ephemeral bool
+ if addedByUser {
+ user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
+ if err != nil {
+ return fmt.Errorf("failed to get user groups: %w", err)
+ }
+ groupsToAdd = user.AutoGroups
+ opEvent.InitiatorID = userID
+ opEvent.Activity = activity.PeerAddedByUser
} else {
- newPeer.Location.CountryCode = location.Country.ISOCode
- newPeer.Location.CityName = location.City.Names.En
- newPeer.Location.GeoNameID = location.City.GeonameID
- }
- }
+ // Validate the setup key
+ sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey)
+ if err != nil {
+ return fmt.Errorf("failed to get setup key: %w", err)
+ }
- // add peer to 'All' group
- group, err := account.GetGroupAll()
- if err != nil {
- return nil, nil, nil, err
- }
- group.Peers = append(group.Peers, newPeer.ID)
+ if !sk.IsValid() {
+ return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid")
+ }
- var groupsToAdd []string
- if addedByUser {
- groupsToAdd, err = account.getUserGroups(userID)
- if err != nil {
- return nil, nil, nil, err
+ opEvent.InitiatorID = sk.Id
+ opEvent.Activity = activity.PeerAddedWithSetupKey
+ groupsToAdd = sk.AutoGroups
+ ephemeral = sk.Ephemeral
+ setupKeyID = sk.Id
+ setupKeyName = sk.Name
}
- } else {
- groupsToAdd, err = account.getSetupKeyGroups(upperKey)
- if err != nil {
- return nil, nil, nil, err
- }
- }
- if len(groupsToAdd) > 0 {
- for _, s := range groupsToAdd {
- if g, ok := account.Groups[s]; ok && g.Name != "All" {
- g.Peers = append(g.Peers, newPeer.ID)
+		if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" {
+ if am.idpManager != nil {
+ userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
+ if err == nil && userdata != nil {
+ peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0])
+ }
}
}
- }
- newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra)
-
- if addedByUser {
- user, err := account.FindUser(userID)
+ freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname)
if err != nil {
- return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user")
+ return fmt.Errorf("failed to get free DNS label: %w", err)
}
- user.updateLastLogin(newPeer.LastLogin)
- }
- account.Peers[newPeer.ID] = newPeer
- account.Network.IncSerial()
- err = am.Store.SaveAccount(ctx, account)
+ freeIP, err := am.getFreeIP(ctx, transaction, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to get free IP: %w", err)
+ }
+
+ registrationTime := time.Now().UTC()
+ newPeer = &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: accountID,
+ Key: peer.Key,
+ SetupKey: upperKey,
+ IP: freeIP,
+ Meta: peer.Meta,
+ Name: peer.Meta.Hostname,
+ DNSLabel: freeLabel,
+ UserID: userID,
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime},
+ SSHEnabled: false,
+ SSHKey: peer.SSHKey,
+ LastLogin: registrationTime,
+ CreatedAt: registrationTime,
+ LoginExpirationEnabled: addedByUser,
+ Ephemeral: ephemeral,
+ Location: peer.Location,
+ }
+ opEvent.TargetID = newPeer.ID
+ opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
+ if !addedByUser {
+ opEvent.Meta["setup_key_name"] = setupKeyName
+ }
+
+ if am.geo != nil && newPeer.Location.ConnectionIP != nil {
+ location, err := am.geo.Lookup(newPeer.Location.ConnectionIP)
+ if err != nil {
+ log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err)
+ } else {
+ newPeer.Location.CountryCode = location.Country.ISOCode
+ newPeer.Location.CityName = location.City.Names.En
+ newPeer.Location.GeoNameID = location.City.GeonameID
+ }
+ }
+
+ settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to get account settings: %w", err)
+ }
+ newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra)
+
+ err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID)
+ if err != nil {
+ return fmt.Errorf("failed adding peer to All group: %w", err)
+ }
+
+ if len(groupsToAdd) > 0 {
+ for _, g := range groupsToAdd {
+ err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g)
+ if err != nil {
+ return err
+ }
+ }
+ }
+
+ err = transaction.AddPeerToAccount(ctx, newPeer)
+ if err != nil {
+ return fmt.Errorf("failed to add peer to account: %w", err)
+ }
+
+ err = transaction.IncrementNetworkSerial(ctx, accountID)
+ if err != nil {
+ return fmt.Errorf("failed to increment network serial: %w", err)
+ }
+
+ if addedByUser {
+ err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin)
+ if err != nil {
+ return fmt.Errorf("failed to update user last login: %w", err)
+ }
+ } else {
+ err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID)
+ if err != nil {
+ return fmt.Errorf("failed to increment setup key usage: %w", err)
+ }
+ }
+
+ log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
+ return nil
+ })
+
if err != nil {
- return nil, nil, nil, err
+ return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
}
- // Account is saved, we can release the lock
- unlock()
- unlock = nil
-
- opEvent.TargetID = newPeer.ID
- opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain())
- if !addedByUser {
- opEvent.Meta["setup_key_name"] = setupKeyName
+ if newPeer == nil {
+ return nil, nil, nil, fmt.Errorf("new peer is nil")
}
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
- if areGroupChangesAffectPeers(account, groupsToAdd) {
- am.updateAccountPeers(ctx, account)
+ unlock()
+ unlock = nil
+
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
+ am.updateAccountPeers(ctx, account)
+
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, nil, err
}
- postureChecks := am.getPeerPostureChecks(account, peer)
+ postureChecks := am.getPeerPostureChecks(account, newPeer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
+func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
+ takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return nil, fmt.Errorf("failed to get taken IPs: %w", err)
+ }
+
+ network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
+ if err != nil {
+ return nil, fmt.Errorf("failed getting network: %w", err)
+ }
+
+ nextIp, err := AllocatePeerIP(network.Net, takenIps)
+ if err != nil {
+ return nil, fmt.Errorf("failed to allocate new peer ip: %w", err)
+ }
+
+ return nextIp, nil
+}
+
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
@@ -563,16 +582,23 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
- err = checkIfPeerOwnerIsBlocked(peer, account)
- if err != nil {
- return nil, nil, nil, err
+ if peer.UserID != "" {
+ user, err := account.FindUser(peer.UserID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ err = checkIfPeerOwnerIsBlocked(peer, user)
+ if err != nil {
+ return nil, nil, nil, err
+ }
}
if peerLoginExpired(ctx, peer, account.Settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
- peer, updated := updatePeerMeta(peer, sync.Meta, account)
+ updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
@@ -589,22 +615,16 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, err
}
+ var postureChecks []*posture.Checks
+
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
- return peer, emptyMap, nil, nil
+ return peer, emptyMap, postureChecks, nil
}
- peer, peerMetaUpdated := updatePeerMeta(peer, sync.Meta, account)
- if peerMetaUpdated {
- err = am.Store.SaveAccount(ctx, account)
- if err != nil {
- return nil, nil, nil, err
- }
- }
-
- if isStatusChanged || (peerMetaUpdated && sync.UpdateAccountPeers) {
+ if isStatusChanged {
am.updateAccountPeers(ctx, account)
}
@@ -612,7 +632,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil {
return nil, nil, nil, err
}
- postureChecks := am.getPeerPostureChecks(account, peer)
+ postureChecks = am.getPeerPostureChecks(account, peer)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -644,31 +664,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// it means that the client has already checked if it needs login and had been through the SSO flow
// so, we can skip this check and directly proceed with the login
if login.UserID == "" {
+		log.WithContext(ctx).Info("peer needs login")
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
}
}
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
+ defer unlockAccount()
+ unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
defer func() {
- if unlock != nil {
- unlock()
+ if unlockPeer != nil {
+ unlockPeer()
}
}()
- // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
- account, err := am.Store.GetAccount(ctx, accountID)
+ peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
- peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
- if err != nil {
- return nil, nil, nil, status.NewPeerNotRegisteredError()
- }
-
- err = checkIfPeerOwnerIsBlocked(peer, account)
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
@@ -676,21 +693,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
- if peerLoginExpired(ctx, peer, account.Settings) {
- err = am.handleExpiredPeer(ctx, login, account, peer)
+
+ if login.UserID != "" {
+ changed, err := am.handleUserPeer(ctx, peer, settings)
if err != nil {
return nil, nil, nil, err
}
- updateRemotePeers = true
- shouldStorePeer = true
+ if changed {
+ shouldStorePeer = true
+ updateRemotePeers = true
+ }
}
- isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
+ groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
- peer, updated := updatePeerMeta(peer, login.Meta, account)
+ var grps []string
+ for _, group := range groups {
+ for _, id := range group.Peers {
+ if id == peer.ID {
+ grps = append(grps, group.ID)
+ break
+ }
+ }
+ }
+
+ isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
shouldStorePeer = true
}
@@ -707,8 +742,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}
- unlock()
- unlock = nil
+ unlockPeer()
+ unlockPeer = nil
+
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
@@ -723,7 +763,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error {
- peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
+ peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey)
if err != nil {
return err
}
@@ -734,7 +774,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
- settings, err := am.Store.GetAccountSettings(ctx, accountID)
+ settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -766,36 +806,30 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
-func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
- err := checkAuth(ctx, login.UserID, peer)
+func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
+ err := checkAuth(ctx, user.Id, peer)
if err != nil {
return err
}
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
- updatePeerLastLogin(peer, account)
-
- // sync user last login with peer last login
- user, err := account.FindUser(login.UserID)
- if err != nil {
- return status.Errorf(status.Internal, "couldn't find user")
- }
-
- err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
+ peer = peer.UpdateLastLogin()
+ err = am.Store.SavePeer(ctx, peer.AccountID, peer)
if err != nil {
return err
}
- am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
+ err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
+ if err != nil {
+ return err
+ }
+
+ am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
return nil
}
-func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
+func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
if peer.AddedWithSSOLogin() {
- user, err := account.FindUser(peer.UserID)
- if err != nil {
- return status.Errorf(status.PermissionDenied, "user doesn't exist")
- }
if user.IsBlocked() {
return status.Errorf(status.PermissionDenied, "user is blocked")
}
@@ -825,9 +859,49 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
return false
}
-func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
- peer.UpdateLastLogin()
+// UpdatePeerSSHKey updates peer's public SSH key
+func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
+ if sshKey == "" {
+ log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID)
+ return nil
+ }
+
+ account, err := am.Store.GetAccountByPeerID(ctx, peerID)
+ if err != nil {
+ return err
+ }
+
+ unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
+ defer unlock()
+
+ // re-fetch the account to pick up any modifications made in the meantime (we were outside the account lock when we first fetched it)
+ account, err = am.Store.GetAccount(ctx, account.Id)
+ if err != nil {
+ return err
+ }
+
+ peer := account.GetPeer(peerID)
+ if peer == nil {
+ return status.Errorf(status.NotFound, "peer with ID %s not found", peerID)
+ }
+
+ if peer.SSHKey == sshKey {
+ log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID)
+ return nil
+ }
+
+ peer.SSHKey = sshKey
account.UpdatePeer(peer)
+
+ err = am.Store.SaveAccount(ctx, account)
+ if err != nil {
+ return err
+ }
+
+ // trigger network map update
+ am.updateAccountPeers(ctx, account)
+
+ return nil
}
// GetPeer for a given accountID, peerID and userID error if not found.
@@ -883,14 +957,6 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
}
-func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
- if peer.UpdateMetaIfNew(meta) {
- account.UpdatePeer(peer)
- return peer, true
- }
- return peer, false
-}
-
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
@@ -929,7 +995,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
- update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
+ update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks})
}(peer)
}
@@ -937,14 +1003,10 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
-// IsPeerInActiveGroup checks if the given peer is part of a group that is used
-// in an active DNS, route, or ACL configuration.
-func isPeerInActiveGroup(account *Account, peerID string) bool {
- peerGroupIDs := make([]string, 0)
- for _, group := range account.Groups {
- if slices.Contains(group.Peers, peerID) {
- peerGroupIDs = append(peerGroupIDs, group.ID)
- }
+func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
+ labelMap := make(map[string]struct{}, len(existingLabels))
+ for _, label := range existingLabels {
+ labelMap[label] = struct{}{}
}
- return areGroupChangesAffectPeers(account, peerGroupIDs)
+ return labelMap
}
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 102923e37..387adb91d 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -7,6 +7,7 @@ import (
"net"
"net/netip"
"os"
+ "runtime"
"testing"
"time"
@@ -19,9 +20,11 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
+ "github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
)
@@ -248,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
}
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
+ err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -296,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy)
+ err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -643,7 +646,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
})
}
-
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
@@ -849,9 +851,9 @@ func TestToSyncResponse(t *testing.T) {
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
- turnCredentials := &TURNCredentials{
- Username: "turn-user",
- Password: "turn-pass",
+ turnRelayToken := &Token{
+ Payload: "turn-user",
+ Signature: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
@@ -917,7 +919,7 @@ func TestToSyncResponse(t *testing.T) {
}
dnsCache := &DNSConfigCache{}
- response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
+ response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache)
assert.NotNil(t, response)
// assert peer config
@@ -988,164 +990,192 @@ func TestToSyncResponse(t *testing.T) {
// assert network map Firewall
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
- assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
- assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
- assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
+ assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction)
+ assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
+ assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol)
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
// assert posture checks
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}
-func TestPeerAccountPeerUpdate(t *testing.T) {
- manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
+func Test_RegisterPeerByUser(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
- err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID)
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+
+ eventStore := &activity.InMemoryEventStore{}
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ assert.NoError(t, err)
+
+ am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
+ assert.NoError(t, err)
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003"
+
+ _, err = store.GetAccount(context.Background(), existingAccountID)
require.NoError(t, err)
- err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
- ID: "group-id",
- Name: "GroupA",
- Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- })
- require.NoError(t, err)
-
- // create a user with auto groups
- _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{
- Id: "regularUser1",
- AccountID: account.Id,
- Role: UserRoleAdmin,
- Issued: UserIssuedAPI,
- AutoGroups: []string{"group-id"},
- }, true)
- require.NoError(t, err)
-
- var peer4 *nbpeer.Peer
-
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
- t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
- })
-
- // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
- t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- _, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2)
- require.NoError(t, err)
-
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldNotReceiveUpdate")
- }
- })
-
- // Adding peer with an unused group in active dns, route, acl should not update account peers and not send peer update
- t.Run("adding peer with unused group", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- key, err := wgtypes.GeneratePrivateKey()
- require.NoError(t, err)
-
- expectedPeerKey := key.PublicKey().String()
- peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
- Key: expectedPeerKey,
- Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
- })
- require.NoError(t, err)
-
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldNotReceiveUpdate")
- }
- })
-
- // Deleting peer with an unused group in active dns, route, acl should not update account peers and not send peer update
- t.Run("deleting peer with unused group", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
- require.NoError(t, err)
-
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldNotReceiveUpdate")
- }
- })
-
- // use the group-id in policy
- err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
- ID: "policy",
- Enabled: true,
- Rules: []*PolicyRule{
- {
- Enabled: true,
- Sources: []string{"group-id"},
- Destinations: []string{"group-id"},
- Bidirectional: true,
- Action: PolicyTrafficActionAccept,
- },
+ newPeer := &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: existingAccountID,
+ Key: "newPeerKey",
+ SetupKey: "",
+ IP: net.IP{123, 123, 123, 123},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "newPeer",
+ GoOS: "linux",
},
- })
+ Name: "newPeerName",
+ DNSLabel: "newPeer.test",
+ UserID: existingUserID,
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
+ SSHEnabled: false,
+ LastLogin: time.Now(),
+ }
+
+ addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer)
require.NoError(t, err)
- // Adding peer with a used group in active dns, route or policy should update account peers and send peer update
- t.Run("adding peer with used group", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
+ peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key)
+ require.NoError(t, err)
+ assert.Equal(t, peer.AccountID, existingAccountID)
+ assert.Equal(t, peer.UserID, existingUserID)
- key, err := wgtypes.GeneratePrivateKey()
- require.NoError(t, err)
+ account, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ assert.Contains(t, account.Peers, addedPeer.ID)
+ assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
- expectedPeerKey := key.PublicKey().String()
- peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{
- Key: expectedPeerKey,
- LoginExpirationEnabled: true,
- Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
- })
- require.NoError(t, err)
+ assert.Equal(t, uint64(1), account.Network.Serial)
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldReceiveUpdate")
- }
- })
+ lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
+ assert.NoError(t, err)
+ assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin)
+}
- //Deleting peer with a used group in active dns, route or acl should update account peers and send peer update
- t.Run("deleting peer with used group", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
+func Test_RegisterPeerBySetupKey(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
- err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID)
- require.NoError(t, err)
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldReceiveUpdate")
- }
- })
+ eventStore := &activity.InMemoryEventStore{}
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ assert.NoError(t, err)
+
+ am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
+ assert.NoError(t, err)
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
+
+ _, err = store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ newPeer := &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: existingAccountID,
+ Key: "newPeerKey",
+ SetupKey: "existingSetupKey",
+ UserID: "",
+ IP: net.IP{123, 123, 123, 123},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "newPeer",
+ GoOS: "linux",
+ },
+ Name: "newPeerName",
+ DNSLabel: "newPeer.test",
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
+ SSHEnabled: false,
+ }
+
+ addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer)
+
+ require.NoError(t, err)
+
+ peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
+ require.NoError(t, err)
+ assert.Equal(t, peer.AccountID, existingAccountID)
+ assert.Equal(t, peer.SetupKey, existingSetupKeyID)
+
+ account, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ assert.Contains(t, account.Peers, addedPeer.ID)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID)
+ assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID)
+
+ assert.Equal(t, uint64(1), account.Network.Serial)
+
+ lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
+ assert.NoError(t, err)
+ assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed)
+ assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes)
}
+
+func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("The SQLite store is not properly supported by Windows yet")
+ }
+
+ store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+
+ eventStore := &activity.InMemoryEventStore{}
+
+ metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
+ assert.NoError(t, err)
+
+ am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
+ assert.NoError(t, err)
+
+ existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC"
+
+ _, err = store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+
+ newPeer := &nbpeer.Peer{
+ ID: xid.New().String(),
+ AccountID: existingAccountID,
+ Key: "newPeerKey",
+ SetupKey: "existingSetupKey",
+ UserID: "",
+ IP: net.IP{123, 123, 123, 123},
+ Meta: nbpeer.PeerSystemMeta{
+ Hostname: "newPeer",
+ GoOS: "linux",
+ },
+ Name: "newPeerName",
+ DNSLabel: "newPeer.test",
+ Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()},
+ SSHEnabled: false,
+ }
+
+ _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
+ require.Error(t, err)
+
+ _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key)
+ require.Error(t, err)
+
+ account, err := store.GetAccount(context.Background(), existingAccountID)
+ require.NoError(t, err)
+ assert.NotContains(t, account.Peers, newPeer.ID)
+ assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID)
+ assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID)
+
+ assert.Equal(t, uint64(0), account.Network.Serial)
+
+ lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z")
+ assert.NoError(t, err)
+ assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed)
+ assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes)
+}
diff --git a/management/server/policy.go b/management/server/policy.go
index e3a6df3f2..75647de44 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -3,6 +3,7 @@ package server
import (
"context"
_ "embed"
+ "slices"
"strconv"
"strings"
@@ -75,6 +76,12 @@ type PolicyUpdateOperation struct {
Values []string
}
+// RulePortRange represents a range of ports for a firewall rule.
+type RulePortRange struct {
+ Start uint16
+ End uint16
+}
+
// PolicyRule is the metadata of the policy
type PolicyRule struct {
// ID of the policy rule
@@ -109,6 +116,9 @@ type PolicyRule struct {
// Ports or it ranges list
Ports []string `gorm:"serializer:json"`
+
+ // PortRanges a list of port ranges.
+ PortRanges []RulePortRange `gorm:"serializer:json"`
}
// Copy returns a copy of a policy rule
@@ -124,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Bidirectional: pm.Bidirectional,
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
+ PortRanges: make([]RulePortRange, len(pm.PortRanges)),
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
+ copy(rule.PortRanges, pm.PortRanges)
return rule
}
@@ -191,18 +203,6 @@ func (p *Policy) UpgradeAndFix() {
}
}
-// ruleGroups returns a list of all groups referenced in the policy's rules,
-// including sources and destinations.
-func (p *Policy) ruleGroups() []string {
- groups := make([]string, 0)
- for _, rule := range p.Rules {
- groups = append(groups, rule.Sources...)
- groups = append(groups, rule.Destinations...)
- }
-
- return groups
-}
-
// FirewallRule is a rule of the firewall.
type FirewallRule struct {
// PeerIP of the peer
@@ -326,34 +326,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// GetPolicy from the store
func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
- for _, policy := range account.Policies {
- if policy.ID == policyID {
- return policy, nil
- }
- }
-
- return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID)
+ return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
}
// SavePolicy in the store
-func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error {
+func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -362,7 +348,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err
}
- exists, updateAccountPeers := am.savePolicy(account, policy)
+ if err = am.savePolicy(account, policy, isUpdate); err != nil {
+ return err
+ }
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
@@ -370,14 +358,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
action := activity.PolicyAdded
- if exists {
+ if isUpdate {
action = activity.PolicyUpdated
}
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
- if updateAccountPeers {
- am.updateAccountPeers(ctx, account)
- }
+ am.updateAccountPeers(ctx, account)
return nil
}
@@ -404,33 +390,23 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
- if anyGroupHasPeers(account, policy.ruleGroups()) {
- am.updateAccountPeers(ctx, account)
- }
+ am.updateAccountPeers(ctx, account)
return nil
}
// ListPolicies from the store
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies")
- }
-
- return account.Policies, nil
+ return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
@@ -450,54 +426,47 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string)
return policy, nil
}
-func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists, updateAccountPeers bool) {
- for i, p := range account.Policies {
- if p.ID == policy.ID {
- account.Policies[i] = policy
+// savePolicy saves or updates a policy in the given account.
+// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
+func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error {
+ for index, rule := range policyToSave.Rules {
+ rule.Sources = filterValidGroupIDs(account, rule.Sources)
+ rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
+ policyToSave.Rules[index] = rule
+ }
- exists = true
- updateAccountPeers = anyGroupHasPeers(account, p.ruleGroups()) || anyGroupHasPeers(account, policy.ruleGroups())
+ if policyToSave.SourcePostureChecks != nil {
+ policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
+ }
- break
+ if isUpdate {
+ policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
+ if policyIdx < 0 {
+ return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
}
+
+ // Update the existing policy
+ account.Policies[policyIdx] = policyToSave
+ return nil
}
- if !exists {
- account.Policies = append(account.Policies, policy)
- updateAccountPeers = anyGroupHasPeers(account, policy.ruleGroups())
- }
- return
+
+ // Add the new policy to the account
+ account.Policies = append(account.Policies, policyToSave)
+
+ return nil
}
-func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
- result := make([]*proto.FirewallRule, len(update))
- for i := range update {
- direction := proto.FirewallRule_IN
- if update[i].Direction == firewallRuleDirectionOUT {
- direction = proto.FirewallRule_OUT
- }
- action := proto.FirewallRule_ACCEPT
- if update[i].Action == string(PolicyTrafficActionDrop) {
- action = proto.FirewallRule_DROP
- }
-
- protocol := proto.FirewallRule_UNKNOWN
- switch PolicyRuleProtocolType(update[i].Protocol) {
- case PolicyRuleProtocolALL:
- protocol = proto.FirewallRule_ALL
- case PolicyRuleProtocolTCP:
- protocol = proto.FirewallRule_TCP
- case PolicyRuleProtocolUDP:
- protocol = proto.FirewallRule_UDP
- case PolicyRuleProtocolICMP:
- protocol = proto.FirewallRule_ICMP
- }
+func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
+ result := make([]*proto.FirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
result[i] = &proto.FirewallRule{
- PeerIP: update[i].PeerIP,
- Direction: direction,
- Action: action,
- Protocol: protocol,
- Port: update[i].Port,
+ PeerIP: rule.PeerIP,
+ Direction: getProtoDirection(rule.Direction),
+ Action: getProtoAction(rule.Action),
+ Protocol: getProtoProtocol(rule.Protocol),
+ Port: rule.Port,
}
}
return result
@@ -580,3 +549,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
}
return nil
}
+
+// filterValidPostureChecks filters and returns the posture check IDs from the given list
+// that are valid within the provided account.
+func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
+ result := make([]string, 0, len(postureChecksIds))
+ for _, id := range postureChecksIds {
+ for _, postureCheck := range account.PostureChecks {
+ if id == postureCheck.ID {
+ result = append(result, id)
+ continue
+ }
+ }
+ }
+ return result
+}
+
+// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
+func filterValidGroupIDs(account *Account, groupIDs []string) []string {
+ result := make([]string, 0, len(groupIDs))
+ for _, groupID := range groupIDs {
+ if _, exists := account.Groups[groupID]; exists {
+ result = append(result, groupID)
+ }
+ }
+ return result
+}
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index c0a9c6a65..ca4946703 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -15,30 +15,16 @@ const (
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
+ if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
- for _, postureChecks := range account.PostureChecks {
- if postureChecks.ID == postureChecksID {
- return postureChecks, nil
- }
- }
-
- return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID)
+ return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
}
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
@@ -123,24 +109,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
}
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !user.HasAdminPower() {
+ if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
- return account.PostureChecks, nil
+ return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
diff --git a/management/server/route.go b/management/server/route.go
index 3e3842ab1..39ee6170c 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -4,9 +4,15 @@ import (
"context"
"fmt"
"net/netip"
+ "slices"
+ "strconv"
+ "strings"
"unicode/utf8"
"github.com/rs/xid"
+ log "github.com/sirupsen/logrus"
+
+ nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
@@ -15,31 +21,42 @@ import (
"github.com/netbirdio/netbird/route"
)
+// RouteFirewallRule is a firewall rule applicable for a routed network.
+type RouteFirewallRule struct {
+ // SourceRanges IP ranges of the routing peers.
+ SourceRanges []string
+
+ // Action of the traffic when the rule is applicable
+ Action string
+
+ // Destination a network prefix for the routed traffic
+ Destination string
+
+ // Protocol of the traffic
+ Protocol string
+
+ // Port of the traffic
+ Port uint16
+
+ // PortRange represents the range of ports for a firewall rule
+ PortRange RulePortRange
+
+	// IsDynamic indicates whether the rule is for DNS routing
+ IsDynamic bool
+}
+
// GetRoute gets a route object from account and route IDs
func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
- wantedRoute, found := account.Routes[routeID]
- if found {
- return wantedRoute, nil
- }
-
- return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
+ return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -125,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string {
}
// CreateRoute creates and saves a new route
-func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
+func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
@@ -134,6 +151,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
+ // Do not allow non-Linux peers
+ if peer := account.GetPeer(peerID); peer != nil {
+ if peer.Meta.GoOS != "linux" {
+ return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+ }
+
if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -163,6 +187,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
}
+ if len(accessControlGroupIDs) > 0 {
+ err = validateGroups(accessControlGroupIDs, account.Groups)
+ if err != nil {
+ return nil, err
+ }
+ }
+
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
if err != nil {
return nil, err
@@ -193,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
+ newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
@@ -205,9 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
- if isRouteChangeAffectPeers(account, &newRoute) {
- am.updateAccountPeers(ctx, account)
- }
+ am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -236,6 +266,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
+ // Do not allow non-Linux peers
+ if peer := account.GetPeer(routeToSave.Peer); peer != nil {
+ if peer.Meta.GoOS != "linux" {
+ return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
+ }
+ }
+
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
}
@@ -259,6 +296,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
}
+ if len(routeToSave.AccessControlGroups) > 0 {
+ err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
+ if err != nil {
+ return err
+ }
+ }
+
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
if err != nil {
return err
@@ -269,7 +313,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
- oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
@@ -277,9 +320,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
- if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
- am.updateAccountPeers(ctx, account)
- }
+ am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -296,8 +337,8 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return err
}
- route := account.Routes[routeID]
- if route == nil {
+ routy := account.Routes[routeID]
+ if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
@@ -307,40 +348,25 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return err
}
- if isRouteChangeAffectPeers(account, route) {
- am.updateAccountPeers(ctx, account)
- }
+ am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
- am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
+ am.updateAccountPeers(ctx, account)
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return nil, err
- }
-
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
- routes := make([]*route.Route, 0, len(account.Routes))
- for _, item := range account.Routes {
- routes = append(routes, item)
- }
-
- return routes, nil
+ return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
}
func toProtocolRoute(route *route.Route) *proto.Route {
@@ -371,8 +397,247 @@ func getPlaceholderIP() netip.Prefix {
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
-// isRouteChangeAffectPeers checks if a given route affects peers by determining
-// if it has a routing peer, distribution, or peer groups that include peers
-func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
- return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
+// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
+func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
+ routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
+
+ enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
+ for _, route := range enabledRoutes {
+ // If no access control groups are specified, accept all traffic.
+ if len(route.AccessControlGroups) == 0 {
+ defaultPermit := getDefaultPermit(route)
+ routesFirewallRules = append(routesFirewallRules, defaultPermit...)
+ continue
+ }
+
+ policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups)
+ for _, policy := range policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ for _, rule := range policy.Rules {
+ if !rule.Enabled {
+ continue
+ }
+
+ distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap)
+ rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN)
+ routesFirewallRules = append(routesFirewallRules, rules...)
+ }
+ }
+ }
+
+ return routesFirewallRules
+}
+
+func getDefaultPermit(route *route.Route) []*RouteFirewallRule {
+ var rules []*RouteFirewallRule
+
+ sources := []string{"0.0.0.0/0"}
+ if route.Network.Addr().Is6() {
+ sources = []string{"::/0"}
+ }
+ rule := RouteFirewallRule{
+ SourceRanges: sources,
+ Action: string(PolicyTrafficActionAccept),
+ Destination: route.Network.String(),
+ Protocol: string(PolicyRuleProtocolALL),
+ IsDynamic: route.IsDynamic(),
+ }
+
+ rules = append(rules, &rule)
+
+ // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally
+ if route.IsDynamic() {
+ ruleV6 := rule
+ ruleV6.SourceRanges = []string{"::/0"}
+ rules = append(rules, &ruleV6)
+ }
+
+ return rules
+}
+
+// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups
+// and returns a list of policies that have rules with destinations matching the specified groups.
+func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy {
+ routePolicies := make([]*Policy, 0)
+ for _, groupID := range accessControlGroups {
+ group, ok := account.Groups[groupID]
+ if !ok {
+ continue
+ }
+
+ for _, policy := range account.Policies {
+ for _, rule := range policy.Rules {
+				exist := slices.ContainsFunc(rule.Destinations, func(dst string) bool {
+					return dst == group.ID
+				})
+ if exist {
+ routePolicies = append(routePolicies, policy)
+					break // policy matches; avoid appending it once per matching rule
+ }
+ }
+ }
+ }
+
+ return routePolicies
+}
+
+// generateRouteFirewallRules generates a list of firewall rules for a given route.
+func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
+ rulesExists := make(map[string]struct{})
+ rules := make([]*RouteFirewallRule, 0)
+
+ sourceRanges := make([]string, 0, len(groupPeers))
+ for _, peer := range groupPeers {
+ if peer == nil {
+ continue
+ }
+ sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
+ }
+
+ baseRule := RouteFirewallRule{
+ SourceRanges: sourceRanges,
+ Action: string(rule.Action),
+ Destination: route.Network.String(),
+ Protocol: string(rule.Protocol),
+ IsDynamic: route.IsDynamic(),
+ }
+
+ // generate rule for port range
+ if len(rule.Ports) == 0 {
+ rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
+ } else {
+ rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
+
+ }
+
+ // TODO: generate IPv6 rules for dynamic routes
+
+ return rules
+}
+
+// generateRuleIDBase generates the base rule ID for checking duplicates.
+func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string {
+ return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action
+}
+
+// generateRulesWithPortRanges generates rules for a base rule from the policy rule's port ranges.
+func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
+ rules := make([]*RouteFirewallRule, 0)
+
+ ruleIDBase := generateRuleIDBase(rule, baseRule)
+ if len(rule.Ports) == 0 {
+ if len(rule.PortRanges) == 0 {
+ if _, ok := rulesExists[ruleIDBase]; !ok {
+ rulesExists[ruleIDBase] = struct{}{}
+ rules = append(rules, &baseRule)
+ }
+ } else {
+ for _, portRange := range rule.PortRanges {
+ ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
+ if _, ok := rulesExists[ruleID]; !ok {
+ rulesExists[ruleID] = struct{}{}
+ pr := baseRule
+ pr.PortRange = portRange
+ rules = append(rules, &pr)
+ }
+ }
+ }
+ return rules
+ }
+
+ return rules
+}
+
+// generateRulesWithPorts generates rules when specific ports are provided.
+func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
+ rules := make([]*RouteFirewallRule, 0)
+ ruleIDBase := generateRuleIDBase(rule, baseRule)
+
+ for _, port := range rule.Ports {
+ ruleID := ruleIDBase + port
+ if _, ok := rulesExists[ruleID]; ok {
+ continue
+ }
+ rulesExists[ruleID] = struct{}{}
+
+ pr := baseRule
+ p, err := strconv.ParseUint(port, 10, 16)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
+ continue
+ }
+
+ pr.Port = uint16(p)
+ rules = append(rules, &pr)
+ }
+
+ return rules
+}
+
+func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule {
+ result := make([]*proto.RouteFirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
+ result[i] = &proto.RouteFirewallRule{
+ SourceRanges: rule.SourceRanges,
+ Action: getProtoAction(rule.Action),
+ Destination: rule.Destination,
+ Protocol: getProtoProtocol(rule.Protocol),
+ PortInfo: getProtoPortInfo(rule),
+ IsDynamic: rule.IsDynamic,
+ }
+ }
+
+ return result
+}
+
+// getProtoDirection converts the direction to proto.RuleDirection.
+func getProtoDirection(direction int) proto.RuleDirection {
+ if direction == firewallRuleDirectionOUT {
+ return proto.RuleDirection_OUT
+ }
+ return proto.RuleDirection_IN
+}
+
+// getProtoAction converts the action to proto.RuleAction.
+func getProtoAction(action string) proto.RuleAction {
+ if action == string(PolicyTrafficActionDrop) {
+ return proto.RuleAction_DROP
+ }
+ return proto.RuleAction_ACCEPT
+}
+
+// getProtoProtocol converts the protocol to proto.RuleProtocol.
+func getProtoProtocol(protocol string) proto.RuleProtocol {
+ switch PolicyRuleProtocolType(protocol) {
+ case PolicyRuleProtocolALL:
+ return proto.RuleProtocol_ALL
+ case PolicyRuleProtocolTCP:
+ return proto.RuleProtocol_TCP
+ case PolicyRuleProtocolUDP:
+ return proto.RuleProtocol_UDP
+ case PolicyRuleProtocolICMP:
+ return proto.RuleProtocol_ICMP
+ default:
+ return proto.RuleProtocol_UNKNOWN
+ }
+}
+
+// getProtoPortInfo converts the port info to proto.PortInfo.
+func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
+ var portInfo proto.PortInfo
+ if rule.Port != 0 {
+ portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
+ } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
+ portInfo.PortSelection = &proto.PortInfo_Range_{
+ Range: &proto.PortInfo_Range{
+ Start: uint32(portRange.Start),
+ End: uint32(portRange.End),
+ },
+ }
+ }
+ return &portInfo
}
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 8dcdef884..b556816be 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -2,9 +2,10 @@ package server
import (
"context"
+ "fmt"
+ "net"
"net/netip"
"testing"
- "time"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
@@ -45,18 +46,19 @@ var existingDomains = domain.List{"example.com"}
func TestCreateRoute(t *testing.T) {
type input struct {
- network netip.Prefix
- domains domain.List
- keepRoute bool
- networkType route.NetworkType
- netID route.NetID
- peerKey string
- peerGroupIDs []string
- description string
- masquerade bool
- metric int
- enabled bool
- groups []string
+ network netip.Prefix
+ domains domain.List
+ keepRoute bool
+ networkType route.NetworkType
+ netID route.NetID
+ peerKey string
+ peerGroupIDs []string
+ description string
+ masquerade bool
+ metric int
+ enabled bool
+ groups []string
+ accessControlGroups []string
}
testCases := []struct {
@@ -70,100 +72,107 @@ func TestCreateRoute(t *testing.T) {
{
name: "Happy Path Network",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- networkType: route.IPv4Network,
- netID: "happy",
- peerKey: peer1ID,
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ networkType: route.IPv4Network,
+ netID: "happy",
+ peerKey: peer1ID,
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup1},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetworkType: route.IPv4Network,
- NetID: "happy",
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetworkType: route.IPv4Network,
+ NetID: "happy",
+ Peer: peer1ID,
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1},
+ AccessControlGroups: []string{routeGroup1},
},
},
{
name: "Happy Path Domains",
inputArgs: input{
- domains: domain.List{"domain1", "domain2"},
- keepRoute: true,
- networkType: route.DomainNetwork,
- netID: "happy",
- peerKey: peer1ID,
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ domains: domain.List{"domain1", "domain2"},
+ keepRoute: true,
+ networkType: route.DomainNetwork,
+ netID: "happy",
+ peerKey: peer1ID,
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup1},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
- Network: netip.MustParsePrefix("192.0.2.0/32"),
- Domains: domain.List{"domain1", "domain2"},
- NetworkType: route.DomainNetwork,
- NetID: "happy",
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
- KeepRoute: true,
+ Network: netip.MustParsePrefix("192.0.2.0/32"),
+ Domains: domain.List{"domain1", "domain2"},
+ NetworkType: route.DomainNetwork,
+ NetID: "happy",
+ Peer: peer1ID,
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1},
+ KeepRoute: true,
+ AccessControlGroups: []string{routeGroup1},
},
},
{
name: "Happy Path Peer Groups",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- networkType: route.IPv4Network,
- netID: "happy",
- peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1, routeGroup2},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ networkType: route.IPv4Network,
+ netID: "happy",
+ peerGroupIDs: []string{routeGroupHA1, routeGroupHA2},
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1, routeGroup2},
+ accessControlGroups: []string{routeGroup1, routeGroup2},
},
errFunc: require.NoError,
shouldCreate: true,
expectedRoute: &route.Route{
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetworkType: route.IPv4Network,
- NetID: "happy",
- PeerGroups: []string{routeGroupHA1, routeGroupHA2},
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1, routeGroup2},
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetworkType: route.IPv4Network,
+ NetID: "happy",
+ PeerGroups: []string{routeGroupHA1, routeGroupHA2},
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1, routeGroup2},
+ AccessControlGroups: []string{routeGroup1, routeGroup2},
},
},
{
name: "Both network and domains provided should fail",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- domains: domain.List{"domain1", "domain2"},
- netID: "happy",
- peerKey: peer1ID,
- peerGroupIDs: []string{routeGroupHA1},
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ domains: domain.List{"domain1", "domain2"},
+ netID: "happy",
+ peerKey: peer1ID,
+ peerGroupIDs: []string{routeGroupHA1},
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup2},
},
errFunc: require.Error,
shouldCreate: false,
@@ -171,16 +180,17 @@ func TestCreateRoute(t *testing.T) {
{
name: "Both peer and peer_groups Provided Should Fail",
inputArgs: input{
- network: netip.MustParsePrefix("192.168.0.0/16"),
- networkType: route.IPv4Network,
- netID: "happy",
- peerKey: peer1ID,
- peerGroupIDs: []string{routeGroupHA1},
- description: "super",
- masquerade: false,
- metric: 9999,
- enabled: true,
- groups: []string{routeGroup1},
+ network: netip.MustParsePrefix("192.168.0.0/16"),
+ networkType: route.IPv4Network,
+ netID: "happy",
+ peerKey: peer1ID,
+ peerGroupIDs: []string{routeGroupHA1},
+ description: "super",
+ masquerade: false,
+ metric: 9999,
+ enabled: true,
+ groups: []string{routeGroup1},
+ accessControlGroups: []string{routeGroup2},
},
errFunc: require.Error,
shouldCreate: false,
@@ -424,13 +434,13 @@ func TestCreateRoute(t *testing.T) {
if testCase.createInitRoute {
groupAll, errInit := account.GetGroupAll()
require.NoError(t, errInit)
- _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
+ _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false)
require.NoError(t, errInit)
- _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false)
+ _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false)
require.NoError(t, errInit)
}
- outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
+ outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute)
testCase.errFunc(t, err)
@@ -1038,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) {
func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
baseRoute := &route.Route{
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetID: "superNet",
- NetworkType: route.IPv4Network,
- PeerGroups: []string{routeGroupHA1, routeGroupHA2},
- Description: "ha route",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1, routeGroup2},
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetID: "superNet",
+ NetworkType: route.IPv4Network,
+ PeerGroups: []string{routeGroupHA1, routeGroupHA2},
+ Description: "ha route",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1, routeGroup2},
+ AccessControlGroups: []string{routeGroup1},
}
am, err := createRouterManager(t)
@@ -1063,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
- newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
+ newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute)
require.NoError(t, err)
require.Equal(t, newRoute.Enabled, true)
@@ -1128,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
// no routes for peer in different groups
// no routes when route is deleted
baseRoute := &route.Route{
- ID: "testingRoute",
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetID: "superNet",
- NetworkType: route.IPv4Network,
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
+ ID: "testingRoute",
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetID: "superNet",
+ NetworkType: route.IPv4Network,
+ Peer: peer1ID,
+ Description: "super",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{routeGroup1},
+ AccessControlGroups: []string{routeGroup1},
}
am, err := createRouterManager(t)
@@ -1154,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
require.NoError(t, err)
require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes")
- createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute)
+ createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute)
require.NoError(t, err)
noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID)
@@ -1206,7 +1218,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
- err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
+ err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
@@ -1273,11 +1285,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer1 := &nbpeer.Peer{
- IP: peer1IP,
- ID: peer1ID,
- Key: peer1Key,
- Name: "test-host1@netbird.io",
- UserID: userID,
+ IP: peer1IP,
+ ID: peer1ID,
+ Key: peer1Key,
+ Name: "test-host1@netbird.io",
+ DNSLabel: "test-host1",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host1@netbird.io",
GoOS: "linux",
@@ -1299,11 +1312,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer2 := &nbpeer.Peer{
- IP: peer2IP,
- ID: peer2ID,
- Key: peer2Key,
- Name: "test-host2@netbird.io",
- UserID: userID,
+ IP: peer2IP,
+ ID: peer2ID,
+ Key: peer2Key,
+ Name: "test-host2@netbird.io",
+ DNSLabel: "test-host2",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host2@netbird.io",
GoOS: "linux",
@@ -1325,11 +1339,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer3 := &nbpeer.Peer{
- IP: peer3IP,
- ID: peer3ID,
- Key: peer3Key,
- Name: "test-host3@netbird.io",
- UserID: userID,
+ IP: peer3IP,
+ ID: peer3ID,
+ Key: peer3Key,
+ Name: "test-host3@netbird.io",
+ DNSLabel: "test-host3",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host3@netbird.io",
GoOS: "darwin",
@@ -1351,11 +1366,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer4 := &nbpeer.Peer{
- IP: peer4IP,
- ID: peer4ID,
- Key: peer4Key,
- Name: "test-host4@netbird.io",
- UserID: userID,
+ IP: peer4IP,
+ ID: peer4ID,
+ Key: peer4Key,
+ Name: "test-host4@netbird.io",
+ DNSLabel: "test-host4",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
Hostname: "test-host4@netbird.io",
GoOS: "linux",
@@ -1377,13 +1393,14 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
}
peer5 := &nbpeer.Peer{
- IP: peer5IP,
- ID: peer5ID,
- Key: peer5Key,
- Name: "test-host4@netbird.io",
- UserID: userID,
+ IP: peer5IP,
+ ID: peer5ID,
+ Key: peer5Key,
+ Name: "test-host5@netbird.io",
+ DNSLabel: "test-host5",
+ UserID: userID,
Meta: nbpeer.PeerSystemMeta{
- Hostname: "test-host4@netbird.io",
+ Hostname: "test-host5@netbird.io",
GoOS: "linux",
Kernel: "Linux",
Core: "21.04",
@@ -1464,90 +1481,299 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er
return am.Store.GetAccount(context.Background(), account.Id)
}
-func TestRouteAccountPeerUpdate(t *testing.T) {
- manager, err := createRouterManager(t)
- require.NoError(t, err, "failed to create account manager")
+func TestAccount_getPeersRoutesFirewall(t *testing.T) {
+ var (
+ peerBIp = "100.65.80.39"
+ peerCIp = "100.65.254.139"
+ peerHIp = "100.65.29.55"
+ )
- account, err := initTestRouteAccount(t, manager)
- require.NoError(t, err, "failed to init testing account")
-
- updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
- t.Cleanup(func() {
- manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
- })
-
- baseRoute := route.Route{
- ID: "testingRoute",
- Network: netip.MustParsePrefix("192.168.0.0/16"),
- NetID: "superNet",
- NetworkType: route.IPv4Network,
- Peer: peer1ID,
- Description: "super",
- Masquerade: false,
- Metric: 9999,
- Enabled: true,
- Groups: []string{routeGroup1},
+ account := &Account{
+ Peers: map[string]*nbpeer.Peer{
+ "peerA": {
+ ID: "peerA",
+ IP: net.ParseIP("100.65.14.88"),
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{
+ GoOS: "linux",
+ },
+ },
+ "peerB": {
+ ID: "peerB",
+ IP: net.ParseIP(peerBIp),
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{},
+ },
+ "peerC": {
+ ID: "peerC",
+ IP: net.ParseIP(peerCIp),
+ Status: &nbpeer.PeerStatus{},
+ },
+ "peerD": {
+ ID: "peerD",
+ IP: net.ParseIP("100.65.62.5"),
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{
+ GoOS: "linux",
+ },
+ },
+ "peerE": {
+ ID: "peerE",
+ IP: net.ParseIP("100.65.32.206"),
+ Key: peer1Key,
+ Status: &nbpeer.PeerStatus{},
+ Meta: nbpeer.PeerSystemMeta{
+ GoOS: "linux",
+ },
+ },
+ "peerF": {
+ ID: "peerF",
+ IP: net.ParseIP("100.65.250.202"),
+ Status: &nbpeer.PeerStatus{},
+ },
+ "peerG": {
+ ID: "peerG",
+ IP: net.ParseIP("100.65.13.186"),
+ Status: &nbpeer.PeerStatus{},
+ },
+ "peerH": {
+ ID: "peerH",
+ IP: net.ParseIP(peerHIp),
+ Status: &nbpeer.PeerStatus{},
+ },
+ },
+ Groups: map[string]*nbgroup.Group{
+ "routingPeer1": {
+ ID: "routingPeer1",
+ Name: "RoutingPeer1",
+ Peers: []string{
+ "peerA",
+ },
+ },
+ "routingPeer2": {
+ ID: "routingPeer2",
+ Name: "RoutingPeer2",
+ Peers: []string{
+ "peerD",
+ },
+ },
+ "route1": {
+ ID: "route1",
+ Name: "Route1",
+ Peers: []string{},
+ },
+ "route2": {
+ ID: "route2",
+ Name: "Route2",
+ Peers: []string{},
+ },
+ "finance": {
+ ID: "finance",
+ Name: "Finance",
+ Peers: []string{
+ "peerF",
+ "peerG",
+ },
+ },
+ "dev": {
+ ID: "dev",
+ Name: "Dev",
+ Peers: []string{
+ "peerC",
+ "peerH",
+ "peerB",
+ },
+ },
+ "contractors": {
+ ID: "contractors",
+ Name: "Contractors",
+ Peers: []string{},
+ },
+ },
+ Routes: map[route.ID]*route.Route{
+ "route1": {
+ ID: "route1",
+ Network: netip.MustParsePrefix("192.168.0.0/16"),
+ NetID: "route1",
+ NetworkType: route.IPv4Network,
+ PeerGroups: []string{"routingPeer1", "routingPeer2"},
+ Description: "Route1 ha route",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{"dev"},
+ AccessControlGroups: []string{"route1"},
+ },
+ "route2": {
+ ID: "route2",
+ Network: existingNetwork,
+ NetID: "route2",
+ NetworkType: route.IPv4Network,
+ Peer: "peerE",
+ Description: "Allow",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{"finance"},
+ AccessControlGroups: []string{"route2"},
+ },
+ "route3": {
+ ID: "route3",
+ Network: netip.MustParsePrefix("192.0.2.0/32"),
+ Domains: domain.List{"example.com"},
+ NetID: "route3",
+ NetworkType: route.DomainNetwork,
+ Peer: "peerE",
+ Description: "Allow all traffic to routed DNS network",
+ Masquerade: false,
+ Metric: 9999,
+ Enabled: true,
+ Groups: []string{"contractors"},
+ AccessControlGroups: []string{},
+ },
+ },
+ Policies: []*Policy{
+ {
+ ID: "RuleRoute1",
+ Name: "Route1",
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ ID: "RuleRoute1",
+ Name: "ruleRoute1",
+ Bidirectional: true,
+ Enabled: true,
+ Protocol: PolicyRuleProtocolALL,
+ Action: PolicyTrafficActionAccept,
+ Ports: []string{"80", "320"},
+ Sources: []string{
+ "dev",
+ },
+ Destinations: []string{
+ "route1",
+ },
+ },
+ },
+ },
+ {
+ ID: "RuleRoute2",
+ Name: "Route2",
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ ID: "RuleRoute2",
+ Name: "ruleRoute2",
+ Bidirectional: true,
+ Enabled: true,
+ Protocol: PolicyRuleProtocolTCP,
+ Action: PolicyTrafficActionAccept,
+ PortRanges: []RulePortRange{
+ {
+ Start: 80,
+ End: 350,
+ }, {
+ Start: 80,
+ End: 350,
+ },
+ },
+ Sources: []string{
+ "finance",
+ },
+ Destinations: []string{
+ "route2",
+ },
+ },
+ },
+ },
+ },
}
- // Creating route should not update account peers and send peer update
- t.Run("creating route", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
+ validatedPeers := make(map[string]struct{})
+ for p := range account.Peers {
+ validatedPeers[p] = struct{}{}
+ }
- newRoute, err := manager.CreateRoute(
- context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer,
- baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric,
- baseRoute.Groups, true, userID, baseRoute.KeepRoute,
- )
- require.NoError(t, err)
- baseRoute = *newRoute
+ t.Run("check applied policies for the route", func(t *testing.T) {
+ route1 := account.Routes["route1"]
+ policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups)
+ assert.Len(t, policies, 1)
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldReceiveUpdate")
- }
+ route2 := account.Routes["route2"]
+ policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups)
+ assert.Len(t, policies, 1)
+
+ route3 := account.Routes["route3"]
+ policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups)
+ assert.Len(t, policies, 0)
})
- // Updating the route should update account peers and send peer update
- t.Run("updating route", func(t *testing.T) {
- baseRoute.Groups = []string{routeGroup1, routeGroup2}
+ t.Run("check peer routes firewall rules", func(t *testing.T) {
+ routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers)
+ assert.Len(t, routesFirewallRules, 2)
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
- require.NoError(t, err)
-
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldReceiveUpdate")
+ expectedRoutesFirewallRules := []*RouteFirewallRule{
+ {
+ SourceRanges: []string{
+ fmt.Sprintf(AllowedIPsFormat, peerCIp),
+ fmt.Sprintf(AllowedIPsFormat, peerHIp),
+ fmt.Sprintf(AllowedIPsFormat, peerBIp),
+ },
+ Action: "accept",
+ Destination: "192.168.0.0/16",
+ Protocol: "all",
+ Port: 80,
+ },
+ {
+ SourceRanges: []string{
+ fmt.Sprintf(AllowedIPsFormat, peerCIp),
+ fmt.Sprintf(AllowedIPsFormat, peerHIp),
+ fmt.Sprintf(AllowedIPsFormat, peerBIp),
+ },
+ Action: "accept",
+ Destination: "192.168.0.0/16",
+ Protocol: "all",
+ Port: 320,
+ },
}
- })
+ assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
- // Deleting the route should update account peers and send peer update
- t.Run("deleting route", func(t *testing.T) {
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
+		// peerD is also a routing peer for route1; it should have the same routes firewall rules as peerA
+ routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers)
+ assert.Len(t, routesFirewallRules, 2)
+ assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
- err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID)
- require.NoError(t, err)
+		// peerE is the single routing peer for route2 and route3
+ routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers)
+ assert.Len(t, routesFirewallRules, 3)
- select {
- case <-done:
- case <-time.After(time.Second):
- t.Error("timeout waiting for peerShouldReceiveUpdate")
+ expectedRoutesFirewallRules = []*RouteFirewallRule{
+ {
+ SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"},
+ Action: "accept",
+ Destination: existingNetwork.String(),
+ Protocol: "tcp",
+ PortRange: RulePortRange{Start: 80, End: 350},
+ },
+ {
+ SourceRanges: []string{"0.0.0.0/0"},
+ Action: "accept",
+ Destination: "192.0.2.0/32",
+ Protocol: "all",
+ IsDynamic: true,
+ },
+ {
+ SourceRanges: []string{"::/0"},
+ Action: "accept",
+ Destination: "192.0.2.0/32",
+ Protocol: "all",
+ IsDynamic: true,
+ },
}
+ assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules)
+
+		// peerC is part of route1's distribution groups but should not receive the routes firewall rules
+ routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers)
+ assert.Len(t, routesFirewallRules, 0)
})
}
diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go
index 7c287a554..fa279d4db 100644
--- a/management/server/scheduler_test.go
+++ b/management/server/scheduler_test.go
@@ -63,6 +63,7 @@ func TestScheduler_Cancel(t *testing.T) {
scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
return scheduletime, true
})
+ defer scheduler.Cancel(context.Background(), []string{jobID2})
time.Sleep(sleepTime)
assert.Len(t, scheduler.jobs, 2)
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index 8ae15726b..e84f8fcd6 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -328,26 +328,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
// ListSetupKeys returns a list of all setup keys of the account
func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
+ }
+
+ setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
- if !user.HasAdminPower() && !user.IsServiceUser {
- return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
- }
-
- keys := make([]*SetupKey, 0, len(account.SetupKeys))
- for _, key := range account.SetupKeys {
+ keys := make([]*SetupKey, 0, len(setupKeys))
+ for _, key := range setupKeys {
var k *SetupKey
- if !(user.HasAdminPower() || user.IsServiceUser) {
+ if !user.IsAdminOrServiceUser() {
k = key.HiddenCopy(999)
} else {
k = key.Copy()
@@ -360,44 +358,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- user, err := account.FindUser(userID)
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
+ return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
+ }
+
+ setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
if err != nil {
return nil, err
}
- if !user.HasAdminPower() && !user.IsServiceUser {
- return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies")
- }
-
- var foundKey *SetupKey
- for _, key := range account.SetupKeys {
- if key.Id == keyID {
- foundKey = key.Copy()
- break
- }
- }
- if foundKey == nil {
- return nil, status.Errorf(status.NotFound, "setup key not found")
- }
-
// the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file)
- if foundKey.UpdatedAt.IsZero() {
- foundKey.UpdatedAt = foundKey.CreatedAt
+ if setupKey.UpdatedAt.IsZero() {
+ setupKey.UpdatedAt = setupKey.CreatedAt
}
- if !(user.HasAdminPower() || user.IsServiceUser) {
- foundKey = foundKey.HiddenCopy(999)
+ if !user.IsAdminOrServiceUser() {
+ setupKey = setupKey.HiddenCopy(999)
}
- return foundKey, nil
+ return setupKey, nil
}
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index c44ab7f09..85c68ef44 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -5,6 +5,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "net"
"os"
"path/filepath"
"runtime"
@@ -33,7 +34,9 @@ import (
const (
storeSqliteFileName = "store.db"
idQueryCondition = "id = ?"
+ keyQueryCondition = "key = ?"
accountAndIDQueryCondition = "account_id = ? and id = ?"
+ accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
)
@@ -134,6 +137,12 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u
func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error {
start := time.Now()
+ defer func() {
+ elapsed := time.Since(start)
+ if elapsed > 1*time.Second {
+ log.WithContext(ctx).Tracef("SaveAccount for account %s exceeded 1s, took: %v", account.Id, elapsed)
+ }
+ }()
// todo: remove this check after the issue is resolved
s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain)
@@ -391,31 +400,40 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
}
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) {
- var account Account
-
- result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?",
- strings.ToLower(domain), true, PrivateCategory)
- if result.Error != nil {
- if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
- }
- log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting account from store")
+ accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain)
+ if err != nil {
+ return nil, err
}
// TODO: rework to not call GetAccount
- return s.GetAccount(ctx, account.Id)
+ return s.GetAccount(ctx, accountID)
+}
+
+func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
+ var accountID string
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
+ Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
+ strings.ToLower(domain), true, PrivateCategory,
+ ).First(&accountID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
+ }
+ log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
+ return "", status.Errorf(status.Internal, "issue getting account from store")
+ }
+
+ return accountID, nil
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
var key SetupKey
- result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey))
+ result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey))
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "issue getting setup key from store")
+ return nil, status.NewSetupKeyNotFoundError()
}
if key.AccountID == "" {
@@ -468,6 +486,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
return &user, nil
}
+func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
+ var user User
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
+ Preload(clause.Associations).First(&user, idQueryCondition, userID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.NewUserNotFoundError(userID)
+ }
+ return nil, status.NewGetUserFromStoreError()
+ }
+
+ return &user, nil
+}
+
+func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
+ var groups []*nbgroup.Group
+ result := s.db.Find(&groups, accountIDCondition, accountID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
+ }
+ log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "issue getting groups from store")
+ }
+
+ return groups, nil
+}
+
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
var accounts []Account
result := s.db.Find(&accounts)
@@ -485,6 +531,13 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
}
func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) {
+ start := time.Now()
+ defer func() {
+ elapsed := time.Since(start)
+ if elapsed > 1*time.Second {
+ log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
+ }
+ }()
var account Account
result := s.db.Model(&account).
@@ -494,7 +547,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "account not found")
+ return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -554,7 +607,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
var user User
- result := s.db.Select("account_id").First(&user, idQueryCondition, userID)
+ result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -571,12 +624,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
var peer nbpeer.Peer
- result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID)
+ result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -590,12 +642,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
- result := s.db.Select("account_id").First(&peer, "key = ?", peerKey)
+ result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting account from store")
}
@@ -609,12 +660,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer
var accountID string
- result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID)
+ result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return "", status.Errorf(status.Internal, "issue getting account from store")
}
@@ -622,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
- var user User
var accountID string
- result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
+ result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -636,61 +685,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
- var key SetupKey
var accountID string
- result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID)
+ result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
- log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error)
- return "", status.Errorf(status.Internal, "issue getting setup key from store")
+ return "", status.NewSetupKeyNotFoundError()
+ }
+
+ if accountID == "" {
+ return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
return accountID, nil
}
-func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) {
+func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
+ var ipJSONStrings []string
+
+ // Fetch the IP addresses as JSON strings
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
+ Where("account_id = ?", accountID).
+ Pluck("ip", &ipJSONStrings)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "no peers found for the account")
+ }
+ return nil, status.Errorf(status.Internal, "issue getting IPs from store")
+ }
+
+ // Convert the JSON strings to net.IP objects
+ ips := make([]net.IP, len(ipJSONStrings))
+ for i, ipJSON := range ipJSONStrings {
+ var ip net.IP
+ if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil {
+ return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store")
+ }
+ ips[i] = ip
+ }
+
+ return ips, nil
+}
+
+func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
+ var labels []string
+
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
+ Where("account_id = ?", accountID).
+ Pluck("dns_label", &labels)
+
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "no peers found for the account")
+ }
+ log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "issue getting dns labels from store")
+ }
+
+ return labels, nil
+}
+
+func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
+ var accountNetwork AccountNetwork
+
+ if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, status.NewAccountNotFoundError(accountID)
+ }
+ return nil, status.Errorf(status.Internal, "issue getting network from store")
+ }
+ return accountNetwork.Network, nil
+}
+
+func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
var peer nbpeer.Peer
- result := s.db.First(&peer, "key = ?", peerKey)
+ result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
- log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting peer from store")
}
return &peer, nil
}
-func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) {
+func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
- if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
+ if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
- log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting settings from store")
}
return accountSettings.Settings, nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
-func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
+func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
- result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID)
+ result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return status.Errorf(status.NotFound, "user %s not found", userID)
+ return status.NewUserNotFoundError(userID)
}
- return status.Errorf(status.Internal, "issue getting user from store")
+ return status.NewGetUserFromStoreError()
}
-
user.LastLogin = lastLogin
- return s.db.Save(user).Error
+ return s.db.Save(&user).Error
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
@@ -809,3 +914,276 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore,
return store, nil
}
+
+// GetSetupKeyBySecret fetches a setup key by its plain-text secret. The
+// secret is upper-cased before lookup, matching how keys are stored.
+func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
+	var setupKey SetupKey
+	result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
+		First(&setupKey, keyQueryCondition, strings.ToUpper(key))
+	if result.Error != nil {
+		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			// Missing key: use the shared NotFound helper so all callers
+			// see the same error.
+			return nil, status.NewSetupKeyNotFoundError()
+		}
+		// Any other DB failure is an internal error, not "not found".
+		return nil, status.Errorf(status.Internal, "failed to get setup key from store: %v", result.Error)
+	}
+	return &setupKey, nil
+}
+
+// IncrementSetupKeyUsage atomically bumps the setup key's used_times
+// counter and stamps last_used in a single UPDATE statement.
+func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
+	result := s.db.WithContext(ctx).Model(&SetupKey{}).
+		Where(idQueryCondition, setupKeyID).
+		Updates(map[string]interface{}{
+			"used_times": gorm.Expr("used_times + 1"),
+			"last_used":  time.Now(),
+		})
+
+	if result.Error != nil {
+		return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
+	}
+
+	// Zero affected rows means no key with that ID exists; reuse the
+	// shared helper so the NotFound error is uniform across the store.
+	if result.RowsAffected == 0 {
+		return status.NewSetupKeyNotFoundError()
+	}
+
+	return nil
+}
+
+// AddPeerToAllGroup adds peerID to the account's built-in "All" group.
+// The operation is idempotent: a peer already in the group is left as-is.
+func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
+	var group nbgroup.Group
+
+	result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
+	if result.Error != nil {
+		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			return status.Errorf(status.NotFound, "group 'All' not found for account")
+		}
+		return status.Errorf(status.Internal, "issue finding group 'All'")
+	}
+
+	for _, existingPeerID := range group.Peers {
+		if existingPeerID == peerID {
+			return nil
+		}
+	}
+
+	group.Peers = append(group.Peers, peerID)
+
+	// Propagate ctx to the write as well, so cancellation covers the
+	// UPDATE and not only the preceding SELECT.
+	if err := s.db.WithContext(ctx).Save(&group).Error; err != nil {
+		return status.Errorf(status.Internal, "issue updating group 'All'")
+	}
+
+	return nil
+}
+
+// AddPeerToGroup adds peerId to the group identified by groupID within
+// the account. Idempotent: an existing membership is left unchanged.
+func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
+	var group nbgroup.Group
+
+	result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
+	if result.Error != nil {
+		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			return status.Errorf(status.NotFound, "group not found for account")
+		}
+		return status.Errorf(status.Internal, "issue finding group")
+	}
+
+	for _, existingPeerID := range group.Peers {
+		if existingPeerID == peerId {
+			return nil
+		}
+	}
+
+	group.Peers = append(group.Peers, peerId)
+
+	// Propagate ctx to the write as well, consistent with the read above.
+	if err := s.db.WithContext(ctx).Save(&group).Error; err != nil {
+		return status.Errorf(status.Internal, "issue updating group")
+	}
+
+	return nil
+}
+
+// AddPeerToAccount persists a brand-new peer row; the peer's AccountID
+// field (set by the caller) determines the owning account.
+func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
+	result := s.db.WithContext(ctx).Create(peer)
+	if result.Error != nil {
+		return status.Errorf(status.Internal, "issue adding peer to account")
+	}
+	return nil
+}
+
+// IncrementNetworkSerial bumps the account's network serial by one in a
+// single UPDATE, marking the network map as changed.
+func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
+	update := s.db.WithContext(ctx).
+		Model(&Account{}).
+		Where(idQueryCondition, accountId).
+		Update("network_serial", gorm.Expr("network_serial + 1"))
+	if err := update.Error; err != nil {
+		return status.Errorf(status.Internal, "issue incrementing network serial count")
+	}
+	return nil
+}
+
+// ExecuteInTransaction runs operation inside a single DB transaction,
+// committing on success and rolling back when operation returns an error.
+// Using gorm's Transaction helper instead of a manual Begin/Commit also
+// guarantees a rollback if operation panics, so the transaction (and its
+// row locks) can never be leaked.
+func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
+	return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
+		return operation(s.withTx(tx))
+	})
+}
+
+// withTx returns a Store view bound to the given transaction handle so
+// that operations executed through it run on that transaction.
+// NOTE(review): only the db field is copied; if SqlStore carries other
+// state (metrics, installation PK, ...) this shallow copy drops it —
+// confirm nothing relied on by transactional callers is lost.
+func (s *SqlStore) withTx(tx *gorm.DB) Store {
+	return &SqlStore{
+		db: tx,
+	}
+}
+
+// GetDB returns the underlying *gorm.DB handle of this store.
+func (s *SqlStore) GetDB() *gorm.DB {
+	return s.db
+}
+
+// GetAccountDNSSettings loads only the DNS-settings columns of the
+// account identified by accountID, applying the requested row lock.
+func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
+	var settings AccountDNSSettings
+
+	err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+		First(&settings, idQueryCondition, accountID).Error
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, status.Errorf(status.NotFound, "dns settings not found")
+		}
+		return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", err)
+	}
+	return &settings.DNSSettings, nil
+}
+
+// AccountExists reports whether an account row with the given ID exists,
+// selecting only the id column to keep the query cheap.
+func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
+	var accountID string
+
+	err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+		Select("id").First(&accountID, idQueryCondition, id).Error
+	switch {
+	case err == nil:
+		return accountID != "", nil
+	case errors.Is(err, gorm.ErrRecordNotFound):
+		// Absence is not an error for an existence check.
+		return false, nil
+	default:
+		return false, err
+	}
+}
+
+// GetAccountDomainAndCategory returns just the Domain and DomainCategory
+// columns of the account identified by accountID.
+func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
+	var account Account
+
+	err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+		Select("domain", "domain_category").Where(idQueryCondition, accountID).First(&account).Error
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return "", "", status.Errorf(status.NotFound, "account not found")
+		}
+		return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", err)
+	}
+
+	return account.Domain, account.DomainCategory, nil
+}
+
+// GetGroupByID retrieves a group by ID and account ID, preloading its
+// associations; lockStrength selects the row-locking clause.
+func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
+	return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
+}
+
+// GetGroupByName returns the group with the given name in the account.
+// If several groups share the name, the one holding the most peers wins
+// (ordered by json_array_length(peers) DESC).
+func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
+	var group nbgroup.Group
+
+	err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
+		Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID).Error
+	if err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, status.Errorf(status.NotFound, "group not found")
+		}
+		return nil, status.Errorf(status.Internal, "failed to get group from store: %s", err)
+	}
+	return &group, nil
+}
+
+// GetAccountPolicies retrieves all policies of an account, preloading
+// their associations; lockStrength selects the row-locking clause.
+func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
+	return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
+}
+
+// GetPolicyByID retrieves a policy by its ID and account ID, preloading
+// its associations; lockStrength selects the row-locking clause.
+func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
+	return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
+}
+
+// GetAccountPostureChecks retrieves all posture checks of an account;
+// lockStrength selects the row-locking clause.
+func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
+	return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetPostureChecksByID retrieves posture checks by their ID within an
+// account; lockStrength selects the row-locking clause.
+func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
+	return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
+}
+
+// GetAccountRoutes retrieves all network routes of an account;
+// lockStrength selects the row-locking clause.
+func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
+	return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetRouteByID retrieves a route by its ID within an account;
+// lockStrength selects the row-locking clause.
+func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
+	return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
+}
+
+// GetAccountSetupKeys retrieves all setup keys of an account;
+// lockStrength selects the row-locking clause.
+func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
+	return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetSetupKeyByID retrieves a setup key by its ID within an account;
+// lockStrength selects the row-locking clause.
+func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
+	return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
+}
+
+// GetAccountNameServerGroups retrieves all name-server groups of an
+// account; lockStrength selects the row-locking clause.
+func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
+	return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
+}
+
+// GetNameServerGroupByID retrieves a name-server group by its ID within
+// an account; lockStrength selects the row-locking clause.
+func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
+	return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
+}
+
+// getRecords loads every record of type T belonging to the account,
+// applying the requested row lock. The type name embedded in the error
+// message is derived from T via %T (last dot-separated component).
+func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
+	var records []T
+
+	err := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&records, accountIDCondition, accountID).Error
+	if err != nil {
+		parts := strings.Split(fmt.Sprintf("%T", records), ".")
+		recordType := parts[len(parts)-1]
+		return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
+	}
+
+	return records, nil
+}
+
+// getRecordByID loads a single record of type T by its ID within the
+// account, applying the requested row lock. The type name used in error
+// messages is derived from T via %T (last dot-separated component).
+func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
+	var record T
+
+	err := db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+		First(&record, accountAndIDQueryCondition, accountID, recordID).Error
+	if err != nil {
+		parts := strings.Split(fmt.Sprintf("%T", record), ".")
+		recordType := parts[len(parts)-1]
+
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, status.Errorf(status.NotFound, "%s not found", recordType)
+		}
+		return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
+	}
+	return &record, nil
+}
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index ce4ee531a..64ef36831 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) {
require.NoError(t, err)
require.Equal(t, id, user.PATs[id].ID)
}
+
+// TestSqlite_GetTakenIPs verifies that GetTakenIPs starts empty for the
+// fixture account and grows as peers are added; IPs are expected back in
+// 16-byte form (hence the To16 conversions).
+func TestSqlite_GetTakenIPs(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("The SQLite store is not properly supported by Windows yet")
+	}
+
+	store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+	defer store.Close(context.Background())
+
+	existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+	_, err := store.GetAccount(context.Background(), existingAccountID)
+	require.NoError(t, err)
+
+	// Fixture account has no peers yet.
+	takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	assert.Equal(t, []net.IP{}, takenIPs)
+
+	peer1 := &nbpeer.Peer{
+		ID:        "peer1",
+		AccountID: existingAccountID,
+		IP:        net.IP{1, 1, 1, 1},
+	}
+	err = store.AddPeerToAccount(context.Background(), peer1)
+	require.NoError(t, err)
+
+	takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	ip1 := net.IP{1, 1, 1, 1}.To16()
+	assert.Equal(t, []net.IP{ip1}, takenIPs)
+
+	peer2 := &nbpeer.Peer{
+		ID:        "peer2",
+		AccountID: existingAccountID,
+		IP:        net.IP{2, 2, 2, 2},
+	}
+	err = store.AddPeerToAccount(context.Background(), peer2)
+	require.NoError(t, err)
+
+	takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	ip2 := net.IP{2, 2, 2, 2}.To16()
+	assert.Equal(t, []net.IP{ip1, ip2}, takenIPs)
+
+}
+
+// TestSqlite_GetPeerLabelsInAccount verifies that GetPeerLabelsInAccount
+// starts empty for the fixture account and accumulates each added peer's
+// DNS label.
+func TestSqlite_GetPeerLabelsInAccount(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("The SQLite store is not properly supported by Windows yet")
+	}
+
+	store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+	defer store.Close(context.Background())
+
+	existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+	_, err := store.GetAccount(context.Background(), existingAccountID)
+	require.NoError(t, err)
+
+	// Fixture account has no peers, so no labels yet.
+	labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	assert.Equal(t, []string{}, labels)
+
+	peer1 := &nbpeer.Peer{
+		ID:        "peer1",
+		AccountID: existingAccountID,
+		DNSLabel:  "peer1.domain.test",
+	}
+	err = store.AddPeerToAccount(context.Background(), peer1)
+	require.NoError(t, err)
+
+	labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	assert.Equal(t, []string{"peer1.domain.test"}, labels)
+
+	peer2 := &nbpeer.Peer{
+		ID:        "peer2",
+		AccountID: existingAccountID,
+		DNSLabel:  "peer2.domain.test",
+	}
+	err = store.AddPeerToAccount(context.Background(), peer2)
+	require.NoError(t, err)
+
+	labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels)
+}
+
+// TestSqlite_GetAccountNetwork verifies the network fields (CIDR, DNS,
+// identifier, serial) loaded for the fixture account match the values in
+// testdata/extended-store.json.
+func TestSqlite_GetAccountNetwork(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("The SQLite store is not properly supported by Windows yet")
+	}
+
+	store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+	defer store.Close(context.Background())
+
+	existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+	_, err := store.GetAccount(context.Background(), existingAccountID)
+	require.NoError(t, err)
+
+	network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID)
+	require.NoError(t, err)
+	// Network address is returned in 16-byte form.
+	ip := net.IP{100, 64, 0, 0}.To16()
+	assert.Equal(t, ip, network.Net.IP)
+	assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask)
+	assert.Equal(t, "", network.Dns)
+	assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier)
+	assert.Equal(t, uint64(0), network.Serial)
+}
+
+// TestSqlite_GetSetupKeyBySecret verifies that the fixture's default
+// setup key can be fetched by its plain secret and carries the expected
+// account ID and name.
+func TestSqlite_GetSetupKeyBySecret(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("The SQLite store is not properly supported by Windows yet")
+	}
+	store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+	defer store.Close(context.Background())
+
+	existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+	_, err := store.GetAccount(context.Background(), existingAccountID)
+	require.NoError(t, err)
+
+	setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+	require.NoError(t, err)
+	assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key)
+	assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID)
+	assert.Equal(t, "Default key", setupKey.Name)
+}
+
+// TestSqlite_incrementSetupKeyUsage verifies IncrementSetupKeyUsage bumps
+// the key's UsedTimes counter by exactly one on each call, starting from
+// the fixture's value of zero.
+func TestSqlite_incrementSetupKeyUsage(t *testing.T) {
+	if runtime.GOOS == "windows" {
+		t.Skip("The SQLite store is not properly supported by Windows yet")
+	}
+	store := newSqliteStoreFromFile(t, "testdata/extended-store.json")
+	defer store.Close(context.Background())
+
+	existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+	_, err := store.GetAccount(context.Background(), existingAccountID)
+	require.NoError(t, err)
+
+	setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+	require.NoError(t, err)
+	assert.Equal(t, 0, setupKey.UsedTimes)
+
+	err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
+	require.NoError(t, err)
+
+	setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+	require.NoError(t, err)
+	assert.Equal(t, 1, setupKey.UsedTimes)
+
+	err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id)
+	require.NoError(t, err)
+
+	setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB")
+	require.NoError(t, err)
+	assert.Equal(t, 2, setupKey.UsedTimes)
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index 58b9a84a0..d7fde35b9 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error {
func NewPeerLoginExpiredError() error {
return Errorf(PermissionDenied, "peer login has expired, please log in once more")
}
+
+// NewSetupKeyNotFoundError creates a new Error with NotFound type for a
+// missing setup key; use this everywhere a key lookup fails so the
+// message stays uniform.
+func NewSetupKeyNotFoundError() error {
+	return Errorf(NotFound, "setup key not found")
+}
+
+// NewGetUserFromStoreError creates a new Error with Internal type for a
+// failure while loading a user from the store (not a missing user).
+func NewGetUserFromStoreError() error {
+	return Errorf(Internal, "issue getting user from store")
+}
diff --git a/management/server/store.go b/management/server/store.go
index 864871c8e..f34a73c2d 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -12,6 +12,7 @@ import (
"strings"
"time"
+ "github.com/netbirdio/netbird/dns"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
@@ -27,45 +28,94 @@ import (
"github.com/netbirdio/netbird/route"
)
+type LockingStrength string
+
+const (
+ LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes.
+ LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions.
+ LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows.
+ LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates.
+)
+
type Store interface {
GetAllAccounts(ctx context.Context) []*Account
GetAccount(ctx context.Context, accountID string) (*Account, error)
- DeleteAccount(ctx context.Context, account *Account) error
+ AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
+ GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error)
GetAccountByUser(ctx context.Context, userID string) (*Account, error)
GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error)
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
- GetAccountIDByUserID(peerKey string) (string, error)
+ GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
- GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
- GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
- GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
+ GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
+ GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
+ GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveAccount(ctx context.Context, account *Account) error
+ DeleteAccount(ctx context.Context, account *Account) error
+
+ GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
+ GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
SaveUsers(accountID string, users map[string]*User) error
- SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
+ SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
+ GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error
+
+ GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
+ GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
+ GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
+ SaveGroups(accountID string, groups map[string]*nbgroup.Group) error
+
+ GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
+ GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
+
+ GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
+ GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
+ GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
+
+ GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
+ AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
+ AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
+ AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
+ GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
+ SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
+ SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
+ SavePeerLocation(accountID string, peer *nbpeer.Peer) error
+
+ GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
+ IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
+ GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
+ GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
+
+ GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
+ GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
+
+ GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
+ GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
+
+ GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
+ IncrementNetworkSerial(ctx context.Context, accountId string) error
+ GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
+
GetInstallationID() string
SaveInstallationID(ctx context.Context, ID string) error
+
// AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock
AcquireWriteLockByUID(ctx context.Context, uniqueID string) func()
// AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock
AcquireReadLockByUID(ctx context.Context, uniqueID string) func()
// AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock
AcquireGlobalLock(ctx context.Context) func()
- SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
- SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
- SavePeerLocation(accountID string, peer *nbpeer.Peer) error
- SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error
+
// Close should close the store persisting all unsaved data.
Close(ctx context.Context) error
// GetStoreEngine should return StoreEngine of the current store implementation.
// This is also a method of metrics.DataSource interface.
GetStoreEngine() StoreEngine
- GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error)
- GetAccountSettings(ctx context.Context, accountID string) (*Settings, error)
+ ExecuteInTransaction(ctx context.Context, f func(store Store) error) error
}
type StoreEngine string
diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go
index a80453dca..357f019c7 100644
--- a/management/server/telemetry/http_api_metrics.go
+++ b/management/server/telemetry/http_api_metrics.go
@@ -8,6 +8,7 @@ import (
"time"
"github.com/google/uuid"
+ "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
@@ -54,112 +55,89 @@ func (rw *WrappedResponseWriter) WriteHeader(code int) {
// HTTPMiddleware handler used to collect metrics of every request/response coming to the API.
// Also adds request tracing (logging).
type HTTPMiddleware struct {
- meter metric.Meter
- ctx context.Context
+ ctx context.Context
// all HTTP requests by endpoint & method
- httpRequestCounters map[string]metric.Int64Counter
+ httpRequestCounter metric.Int64Counter
// all HTTP responses by endpoint & method & status code
- httpResponseCounters map[string]metric.Int64Counter
+ httpResponseCounter metric.Int64Counter
// all HTTP requests
totalHTTPRequestsCounter metric.Int64Counter
// all HTTP responses
totalHTTPResponseCounter metric.Int64Counter
// all HTTP responses by status code
- totalHTTPResponseCodeCounters map[int]metric.Int64Counter
+ totalHTTPResponseCodeCounter metric.Int64Counter
// all HTTP requests durations by endpoint and method
- httpRequestDurations map[string]metric.Int64Histogram
+ httpRequestDuration metric.Int64Histogram
// all HTTP requests durations
totalHTTPRequestDuration metric.Int64Histogram
}
// NewMetricsMiddleware creates a new HTTPMiddleware
func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) {
- totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpRequestCounterPrefix), metric.WithUnit("1"))
+ httpRequestCounter, err := meter.Int64Counter(httpRequestCounterPrefix, metric.WithUnit("1"))
if err != nil {
return nil, err
}
- totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpResponseCounterPrefix), metric.WithUnit("1"))
+ httpResponseCounter, err := meter.Int64Counter(httpResponseCounterPrefix, metric.WithUnit("1"))
if err != nil {
return nil, err
}
- totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s_total", httpRequestDurationPrefix), metric.WithUnit("milliseconds"))
+ totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s.total", httpRequestCounterPrefix), metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
+ totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s.total", httpResponseCounterPrefix), metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
+ totalHTTPResponseCodeCounter, err := meter.Int64Counter(fmt.Sprintf("%s.code.total", httpResponseCounterPrefix), metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
+ httpRequestDuration, err := meter.Int64Histogram(httpRequestDurationPrefix, metric.WithUnit("milliseconds"))
+ if err != nil {
+ return nil, err
+ }
+
+ totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s.total", httpRequestDurationPrefix), metric.WithUnit("milliseconds"))
if err != nil {
return nil, err
}
return &HTTPMiddleware{
- ctx: ctx,
- httpRequestCounters: map[string]metric.Int64Counter{},
- httpResponseCounters: map[string]metric.Int64Counter{},
- httpRequestDurations: map[string]metric.Int64Histogram{},
- totalHTTPResponseCodeCounters: map[int]metric.Int64Counter{},
- meter: meter,
- totalHTTPRequestsCounter: totalHTTPRequestsCounter,
- totalHTTPResponseCounter: totalHTTPResponseCounter,
- totalHTTPRequestDuration: totalHTTPRequestDuration,
+ ctx: ctx,
+ httpRequestCounter: httpRequestCounter,
+ httpResponseCounter: httpResponseCounter,
+ httpRequestDuration: httpRequestDuration,
+ totalHTTPResponseCodeCounter: totalHTTPResponseCodeCounter,
+ totalHTTPRequestsCounter: totalHTTPRequestsCounter,
+ totalHTTPResponseCounter: totalHTTPResponseCounter,
+ totalHTTPRequestDuration: totalHTTPRequestDuration,
},
nil
}
-// AddHTTPRequestResponseCounter adds a new meter for an HTTP defaultEndpoint and Method (GET, POST, etc)
-// Creates one request counter and multiple response counters (one per http response status code).
-func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error {
- meterKey := getRequestCounterKey(endpoint, method)
- httpReqCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
- if err != nil {
- return err
- }
- m.httpRequestCounters[meterKey] = httpReqCounter
-
- durationKey := getRequestDurationKey(endpoint, method)
- requestDuration, err := m.meter.Int64Histogram(durationKey, metric.WithUnit("milliseconds"))
- if err != nil {
- return err
- }
- m.httpRequestDurations[durationKey] = requestDuration
-
- respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503}
- for _, code := range respCodes {
- meterKey = getResponseCounterKey(endpoint, method, code)
- httpRespCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
- if err != nil {
- return err
- }
- m.httpResponseCounters[meterKey] = httpRespCounter
-
- meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code)
- totalHTTPResponseCodeCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1"))
- if err != nil {
- return err
- }
- m.totalHTTPResponseCodeCounters[code] = totalHTTPResponseCodeCounter
- }
-
- return nil
-}
-
func replaceEndpointChars(endpoint string) string {
- endpoint = strings.ReplaceAll(endpoint, "/", "_")
endpoint = strings.ReplaceAll(endpoint, "{", "")
endpoint = strings.ReplaceAll(endpoint, "}", "")
return endpoint
}
-func getRequestCounterKey(endpoint, method string) string {
- endpoint = replaceEndpointChars(endpoint)
- return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix, endpoint, method)
-}
-
-func getRequestDurationKey(endpoint, method string) string {
- endpoint = replaceEndpointChars(endpoint)
- return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix, endpoint, method)
-}
-
-func getResponseCounterKey(endpoint, method string, status int) string {
- endpoint = replaceEndpointChars(endpoint)
- return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix, endpoint, method, status)
+// getEndpointMetricAttr derives a low-cardinality endpoint label for
+// metrics from the matched mux route's path template (braces stripped by
+// replaceEndpointChars, e.g. "/api/peers/{peerId}" -> "/api/peers/peerId").
+// Returns "" when the request didn't match a route or the template can't
+// be resolved, so unmatched requests collapse into a single bucket.
+func getEndpointMetricAttr(r *http.Request) string {
+	var endpoint string
+	route := mux.CurrentRoute(r)
+	if route != nil {
+		pathTmpl, err := route.GetPathTemplate()
+		if err == nil {
+			endpoint = replaceEndpointChars(pathTmpl)
+		}
+	}
+	return endpoint
+}
// Handler logs every request and response and adds them to metrics.
@@ -176,11 +154,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL)
- metricKey := getRequestCounterKey(r.URL.Path, r.Method)
+ endpointAttr := attribute.String("endpoint", getEndpointMetricAttr(r))
+ methodAttr := attribute.String("method", r.Method)
- if c, ok := m.httpRequestCounters[metricKey]; ok {
- c.Add(m.ctx, 1)
- }
+ m.httpRequestCounter.Add(m.ctx, 1, metric.WithAttributes(endpointAttr, methodAttr))
m.totalHTTPRequestsCounter.Add(m.ctx, 1)
w := WrapResponseWriter(rw)
@@ -193,21 +170,14 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
log.WithContext(ctx).Tracef("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status())
}
- metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status())
- if c, ok := m.httpResponseCounters[metricKey]; ok {
- c.Add(m.ctx, 1)
- }
+ statusCodeAttr := attribute.Int("code", w.Status())
+ m.httpResponseCounter.Add(m.ctx, 1, metric.WithAttributes(endpointAttr, methodAttr, statusCodeAttr))
m.totalHTTPResponseCounter.Add(m.ctx, 1)
- if c, ok := m.totalHTTPResponseCodeCounters[w.Status()]; ok {
- c.Add(m.ctx, 1)
- }
+ m.totalHTTPResponseCodeCounter.Add(m.ctx, 1, metric.WithAttributes(statusCodeAttr))
- durationKey := getRequestDurationKey(r.URL.Path, r.Method)
reqTook := time.Since(reqStart)
- if c, ok := m.httpRequestDurations[durationKey]; ok {
- c.Record(m.ctx, reqTook.Milliseconds())
- }
+ m.httpRequestDuration.Record(m.ctx, reqTook.Milliseconds(), metric.WithAttributes(endpointAttr, methodAttr))
log.WithContext(ctx).Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status())
if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) {
diff --git a/management/server/testdata/GeoLite2-City-Test.mmdb b/management/server/testdata/GeoLite2-City_20240305.mmdb
similarity index 100%
rename from management/server/testdata/GeoLite2-City-Test.mmdb
rename to management/server/testdata/GeoLite2-City_20240305.mmdb
diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json
new file mode 100644
index 000000000..7f96e57a8
--- /dev/null
+++ b/management/server/testdata/extended-store.json
@@ -0,0 +1,120 @@
+{
+ "Accounts": {
+ "bf1c8084-ba50-4ce7-9439-34653001fc3b": {
+ "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b",
+ "CreatedBy": "",
+ "Domain": "test.com",
+ "DomainCategory": "private",
+ "IsDomainPrimaryAccount": true,
+ "SetupKeys": {
+ "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": {
+ "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
+ "AccountID": "",
+ "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB",
+ "Name": "Default key",
+ "Type": "reusable",
+ "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
+ "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
+ "UpdatedAt": "0001-01-01T00:00:00Z",
+ "Revoked": false,
+ "UsedTimes": 0,
+ "LastUsed": "0001-01-01T00:00:00Z",
+ "AutoGroups": ["cfefqs706sqkneg59g2g"],
+ "UsageLimit": 0,
+ "Ephemeral": false
+ },
+ "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": {
+ "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
+ "AccountID": "",
+ "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC",
+ "Name": "Faulty key with non existing group",
+ "Type": "reusable",
+ "CreatedAt": "2021-08-19T20:46:20.005936822+02:00",
+ "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00",
+ "UpdatedAt": "0001-01-01T00:00:00Z",
+ "Revoked": false,
+ "UsedTimes": 0,
+ "LastUsed": "0001-01-01T00:00:00Z",
+ "AutoGroups": ["abcd"],
+ "UsageLimit": 0,
+ "Ephemeral": false
+ }
+ },
+ "Network": {
+ "id": "af1c8024-ha40-4ce2-9418-34653101fc3c",
+ "Net": {
+ "IP": "100.64.0.0",
+ "Mask": "//8AAA=="
+ },
+ "Dns": "",
+ "Serial": 0
+ },
+ "Peers": {},
+ "Users": {
+ "edafee4e-63fb-11ec-90d6-0242ac120003": {
+ "Id": "edafee4e-63fb-11ec-90d6-0242ac120003",
+ "AccountID": "",
+ "Role": "admin",
+ "IsServiceUser": false,
+ "ServiceUserName": "",
+ "AutoGroups": ["cfefqs706sqkneg59g3g"],
+ "PATs": {},
+ "Blocked": false,
+ "LastLogin": "0001-01-01T00:00:00Z"
+ },
+ "f4f6d672-63fb-11ec-90d6-0242ac120003": {
+ "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003",
+ "AccountID": "",
+ "Role": "user",
+ "IsServiceUser": false,
+ "ServiceUserName": "",
+ "AutoGroups": null,
+ "PATs": {
+ "9dj38s35-63fb-11ec-90d6-0242ac120003": {
+ "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003",
+ "UserID": "",
+ "Name": "",
+ "HashedToken": "SoMeHaShEdToKeN",
+ "ExpirationDate": "2023-02-27T00:00:00Z",
+ "CreatedBy": "user",
+ "CreatedAt": "2023-01-01T00:00:00Z",
+ "LastUsed": "2023-02-01T00:00:00Z"
+ }
+ },
+ "Blocked": false,
+ "LastLogin": "0001-01-01T00:00:00Z"
+ }
+ },
+ "Groups": {
+ "cfefqs706sqkneg59g4g": {
+ "ID": "cfefqs706sqkneg59g4g",
+ "Name": "All",
+ "Peers": []
+ },
+ "cfefqs706sqkneg59g3g": {
+ "ID": "cfefqs706sqkneg59g3g",
+ "Name": "AwesomeGroup1",
+ "Peers": []
+ },
+ "cfefqs706sqkneg59g2g": {
+ "ID": "cfefqs706sqkneg59g2g",
+ "Name": "AwesomeGroup2",
+ "Peers": []
+ }
+ },
+ "Rules": null,
+ "Policies": [],
+ "Routes": null,
+ "NameServerGroups": null,
+ "DNSSettings": null,
+ "Settings": {
+ "PeerLoginExpirationEnabled": false,
+ "PeerLoginExpiration": 86400000000000,
+ "GroupsPropagationEnabled": false,
+ "JWTGroupsEnabled": false,
+ "JWTGroupsClaimName": ""
+ }
+ }
+ },
+ "InstallationID": ""
+}
diff --git a/management/server/testdata/geonames-test.db b/management/server/testdata/geonames_20240305.db
similarity index 100%
rename from management/server/testdata/geonames-test.db
rename to management/server/testdata/geonames_20240305.db
diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go
new file mode 100644
index 000000000..ef8276b59
--- /dev/null
+++ b/management/server/token_mgr.go
@@ -0,0 +1,232 @@
+package server
+
+import (
+ "context"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/management/proto"
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
+)
+
+const defaultDuration = 12 * time.Hour
+
+// SecretsManager used to manage TURN and relay secrets
+type SecretsManager interface {
+ GenerateTurnToken() (*Token, error)
+ GenerateRelayToken() (*Token, error)
+ SetupRefresh(ctx context.Context, peerKey string)
+ CancelRefresh(peerKey string)
+}
+
+// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
+type TimeBasedAuthSecretsManager struct {
+ mux sync.Mutex
+ turnCfg *TURNConfig
+ relayCfg *Relay
+ turnHmacToken *auth.TimedHMAC
+ relayHmacToken *authv2.Generator
+ updateManager *PeersUpdateManager
+ turnCancelMap map[string]chan struct{}
+ relayCancelMap map[string]chan struct{}
+}
+
+type Token auth.Token
+
+func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay) *TimeBasedAuthSecretsManager {
+ mgr := &TimeBasedAuthSecretsManager{
+ updateManager: updateManager,
+ turnCfg: turnCfg,
+ relayCfg: relayCfg,
+ turnCancelMap: make(map[string]chan struct{}),
+ relayCancelMap: make(map[string]chan struct{}),
+ }
+
+ if turnCfg != nil {
+ duration := turnCfg.CredentialsTTL.Duration
+ if turnCfg.CredentialsTTL.Duration <= 0 {
+ log.Warnf("TURN credentials TTL is not set or invalid, using default value %s", defaultDuration)
+ duration = defaultDuration
+ }
+ mgr.turnHmacToken = auth.NewTimedHMAC(turnCfg.Secret, duration)
+ }
+
+ if relayCfg != nil {
+ duration := relayCfg.CredentialsTTL.Duration
+ if relayCfg.CredentialsTTL.Duration <= 0 {
+ log.Warnf("Relay credentials TTL is not set or invalid, using default value %s", defaultDuration)
+ duration = defaultDuration
+ }
+
+ hashedSecret := sha256.Sum256([]byte(relayCfg.Secret))
+ var err error
+ if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil {
+ log.Errorf("failed to create relay token generator: %s", err)
+ }
+ }
+
+ return mgr
+}
+
+// GenerateTurnToken generates new time-based secret credentials for TURN
+func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) {
+ if m.turnHmacToken == nil {
+ return nil, fmt.Errorf("TURN configuration is not set")
+ }
+ turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
+ if err != nil {
+ return nil, fmt.Errorf("generate TURN token: %s", err)
+ }
+ return (*Token)(turnToken), nil
+}
+
+// GenerateRelayToken generates new time-based secret credentials for relay
+func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) {
+ if m.relayHmacToken == nil {
+ return nil, fmt.Errorf("relay configuration is not set")
+ }
+ relayToken, err := m.relayHmacToken.GenerateToken()
+ if err != nil {
+ return nil, fmt.Errorf("generate relay token: %s", err)
+ }
+
+ return &Token{
+ Payload: string(relayToken.Payload),
+ Signature: base64.StdEncoding.EncodeToString(relayToken.Signature),
+ }, nil
+}
+
+func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) {
+ if channel, ok := m.turnCancelMap[peerID]; ok {
+ close(channel)
+ delete(m.turnCancelMap, peerID)
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) cancelRelay(peerID string) {
+ if channel, ok := m.relayCancelMap[peerID]; ok {
+ close(channel)
+ delete(m.relayCancelMap, peerID)
+ }
+}
+
+// CancelRefresh cancels scheduled peer credentials refresh
+func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
+ m.mux.Lock()
+ defer m.mux.Unlock()
+ m.cancelTURN(peerID)
+ m.cancelRelay(peerID)
+}
+
+// SetupRefresh starts peer credentials refresh
+func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) {
+ m.mux.Lock()
+ defer m.mux.Unlock()
+
+ m.cancelTURN(peerID)
+ m.cancelRelay(peerID)
+
+ if m.turnCfg != nil && m.turnCfg.TimeBasedCredentials {
+ turnCancel := make(chan struct{}, 1)
+ m.turnCancelMap[peerID] = turnCancel
+ go m.refreshTURNTokens(ctx, peerID, turnCancel)
+ log.WithContext(ctx).Debugf("starting TURN refresh for %s", peerID)
+ }
+
+ if m.relayCfg != nil {
+ relayCancel := make(chan struct{}, 1)
+ m.relayCancelMap[peerID] = relayCancel
+ go m.refreshRelayTokens(ctx, peerID, relayCancel)
+ log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID)
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, peerID string, cancel chan struct{}) {
+ ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-cancel:
+ log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID)
+ return
+ case <-ticker.C:
+ m.pushNewTURNTokens(ctx, peerID)
+ }
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, peerID string, cancel chan struct{}) {
+ ticker := time.NewTicker(m.relayCfg.CredentialsTTL.Duration / 4 * 3)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-cancel:
+ log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID)
+ return
+ case <-ticker.C:
+ m.pushNewRelayTokens(ctx, peerID)
+ }
+ }
+}
+
+func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, peerID string) {
+ turnToken, err := m.turnHmacToken.GenerateToken(sha1.New)
+ if err != nil {
+ log.Errorf("failed to generate token for peer '%s': %s", peerID, err)
+ return
+ }
+
+ var turns []*proto.ProtectedHostConfig
+ for _, host := range m.turnCfg.Turns {
+ turn := &proto.ProtectedHostConfig{
+ HostConfig: &proto.HostConfig{
+ Uri: host.URI,
+ Protocol: ToResponseProto(host.Proto),
+ },
+ User: turnToken.Payload,
+ Password: turnToken.Signature,
+ }
+ turns = append(turns, turn)
+ }
+
+ update := &proto.SyncResponse{
+ WiretrusteeConfig: &proto.WiretrusteeConfig{
+ Turns: turns,
+ // omit Relay to avoid updates there
+ },
+ }
+
+ log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
+ m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
+}
+
+func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) {
+ relayToken, err := m.relayHmacToken.GenerateToken()
+ if err != nil {
+ log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err)
+ return
+ }
+
+ update := &proto.SyncResponse{
+ WiretrusteeConfig: &proto.WiretrusteeConfig{
+ Relay: &proto.RelayConfig{
+ Urls: m.relayCfg.Addresses,
+ TokenPayload: string(relayToken.Payload),
+ TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature),
+ },
+ // omit Turns to avoid updates there
+ },
+ }
+
+ log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
+ m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
+}
diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go
new file mode 100644
index 000000000..3e63346c2
--- /dev/null
+++ b/management/server/token_mgr_test.go
@@ -0,0 +1,219 @@
+package server
+
+import (
+ "context"
+ "crypto/hmac"
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "hash"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/netbirdio/netbird/management/proto"
+ "github.com/netbirdio/netbird/util"
+)
+
+var TurnTestHost = &Host{
+ Proto: UDP,
+ URI: "turn:turn.wiretrustee.com:77777",
+ Username: "username",
+ Password: "",
+}
+
+func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
+ ttl := util.Duration{Duration: time.Hour}
+ secret := "some_secret"
+ peersManager := NewPeersUpdateManager(nil)
+
+ rc := &Relay{
+ Addresses: []string{"localhost:0"},
+ CredentialsTTL: ttl,
+ Secret: secret,
+ }
+
+ tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
+ CredentialsTTL: ttl,
+ Secret: secret,
+ Turns: []*Host{TurnTestHost},
+ TimeBasedCredentials: true,
+ }, rc)
+
+ turnCredentials, err := tested.GenerateTurnToken()
+ require.NoError(t, err)
+
+ if turnCredentials.Payload == "" {
+ t.Errorf("expected generated TURN username not to be empty, got empty")
+ }
+ if turnCredentials.Signature == "" {
+ t.Errorf("expected generated TURN password not to be empty, got empty")
+ }
+
+ validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret))
+
+ relayCredentials, err := tested.GenerateRelayToken()
+ require.NoError(t, err)
+
+ if relayCredentials.Payload == "" {
+ t.Errorf("expected generated relay payload not to be empty, got empty")
+ }
+ if relayCredentials.Signature == "" {
+ t.Errorf("expected generated relay signature not to be empty, got empty")
+ }
+
+ hashedSecret := sha256.Sum256([]byte(secret))
+ validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:])
+}
+
+func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
+ ttl := util.Duration{Duration: 2 * time.Second}
+ secret := "some_secret"
+ peersManager := NewPeersUpdateManager(nil)
+ peer := "some_peer"
+ updateChannel := peersManager.CreateChannel(context.Background(), peer)
+
+ rc := &Relay{
+ Addresses: []string{"localhost:0"},
+ CredentialsTTL: ttl,
+ Secret: secret,
+ }
+ tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
+ CredentialsTTL: ttl,
+ Secret: secret,
+ Turns: []*Host{TurnTestHost},
+ TimeBasedCredentials: true,
+ }, rc)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tested.SetupRefresh(ctx, peer)
+
+ if _, ok := tested.turnCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in the turn cancel map, got not present")
+ }
+
+ if _, ok := tested.relayCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in the relay cancel map, got not present")
+ }
+
+ var updates []*UpdateMessage
+
+loop:
+ for timeout := time.After(5 * time.Second); ; {
+ select {
+ case update := <-updateChannel:
+ updates = append(updates, update)
+ case <-timeout:
+ break loop
+ }
+
+ if len(updates) >= 2 {
+ break loop
+ }
+ }
+
+ if len(updates) < 2 {
+ t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates))
+ }
+
+ var turnUpdates, relayUpdates int
+ var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig
+ var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig
+
+ for _, update := range updates {
+ if turns := update.Update.GetWiretrusteeConfig().GetTurns(); len(turns) > 0 {
+ turnUpdates++
+ if turnUpdates == 1 {
+ firstTurnUpdate = turns[0]
+ } else {
+ secondTurnUpdate = turns[0]
+ }
+ }
+ if relay := update.Update.GetWiretrusteeConfig().GetRelay(); relay != nil {
+ relayUpdates++
+ if relayUpdates == 1 {
+ firstRelayUpdate = relay
+ } else {
+ secondRelayUpdate = relay
+ }
+ }
+ }
+
+ if turnUpdates < 1 {
+ t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates)
+ }
+ if relayUpdates < 1 {
+ t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates)
+ }
+
+ if firstTurnUpdate != nil && secondTurnUpdate != nil {
+ if firstTurnUpdate.Password == secondTurnUpdate.Password {
+ t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password)
+ }
+ }
+
+ if firstRelayUpdate != nil && secondRelayUpdate != nil {
+ if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature {
+ t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature)
+ }
+ }
+}
+
+func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
+ ttl := util.Duration{Duration: time.Hour}
+ secret := "some_secret"
+ peersManager := NewPeersUpdateManager(nil)
+ peer := "some_peer"
+
+ rc := &Relay{
+ Addresses: []string{"localhost:0"},
+ CredentialsTTL: ttl,
+ Secret: secret,
+ }
+ tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
+ CredentialsTTL: ttl,
+ Secret: secret,
+ Turns: []*Host{TurnTestHost},
+ TimeBasedCredentials: true,
+ }, rc)
+
+ tested.SetupRefresh(context.Background(), peer)
+ if _, ok := tested.turnCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in turn cancel map, got not present")
+ }
+ if _, ok := tested.relayCancelMap[peer]; !ok {
+ t.Errorf("expecting peer to be present in relay cancel map, got not present")
+ }
+
+ tested.CancelRefresh(peer)
+ if _, ok := tested.turnCancelMap[peer]; ok {
+ t.Errorf("expecting peer to be not present in turn cancel map, got present")
+ }
+ if _, ok := tested.relayCancelMap[peer]; ok {
+ t.Errorf("expecting peer to be not present in relay cancel map, got present")
+ }
+}
+
+func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) {
+ t.Helper()
+ mac := hmac.New(algo, key)
+
+ _, err := mac.Write([]byte(username))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expectedMAC := mac.Sum(nil)
+ decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
+ if err != nil {
+ t.Fatal(err)
+ }
+ equal := hmac.Equal(decodedMAC, expectedMAC)
+
+ if !equal {
+ t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
+ }
+}
diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go
deleted file mode 100644
index 79f42e882..000000000
--- a/management/server/turncredentials.go
+++ /dev/null
@@ -1,126 +0,0 @@
-package server
-
-import (
- "context"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
- "fmt"
- "sync"
- "time"
-
- log "github.com/sirupsen/logrus"
-
- "github.com/netbirdio/netbird/management/proto"
-)
-
-// TURNCredentialsManager used to manage TURN credentials
-type TURNCredentialsManager interface {
- GenerateCredentials() TURNCredentials
- SetupRefresh(ctx context.Context, peerKey string)
- CancelRefresh(peerKey string)
-}
-
-// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
-type TimeBasedAuthSecretsManager struct {
- mux sync.Mutex
- config *TURNConfig
- updateManager *PeersUpdateManager
- cancelMap map[string]chan struct{}
-}
-
-type TURNCredentials struct {
- Username string
- Password string
-}
-
-func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager {
- return &TimeBasedAuthSecretsManager{
- mux: sync.Mutex{},
- config: config,
- updateManager: updateManager,
- cancelMap: make(map[string]chan struct{}),
- }
-}
-
-// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret
-func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials {
- mac := hmac.New(sha1.New, []byte(m.config.Secret))
-
- timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix()
-
- username := fmt.Sprint(timeAuth)
-
- _, err := mac.Write([]byte(username))
- if err != nil {
- log.Errorln("Generating turn password failed with error: ", err)
- }
-
- bytePassword := mac.Sum(nil)
- password := base64.StdEncoding.EncodeToString(bytePassword)
-
- return TURNCredentials{
- Username: username,
- Password: password,
- }
-
-}
-
-func (m *TimeBasedAuthSecretsManager) cancel(peerID string) {
- if channel, ok := m.cancelMap[peerID]; ok {
- close(channel)
- delete(m.cancelMap, peerID)
- }
-}
-
-// CancelRefresh cancels scheduled peer credentials refresh
-func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) {
- m.mux.Lock()
- defer m.mux.Unlock()
- m.cancel(peerID)
-}
-
-// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer.
-// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone.
-func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) {
- m.mux.Lock()
- defer m.mux.Unlock()
- m.cancel(peerID)
- cancel := make(chan struct{}, 1)
- m.cancelMap[peerID] = cancel
- log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID)
-
- go func() {
- // we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL)
- ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3)
-
- for {
- select {
- case <-cancel:
- log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID)
- return
- case <-ticker.C:
- c := m.GenerateCredentials()
- var turns []*proto.ProtectedHostConfig
- for _, host := range m.config.Turns {
- turns = append(turns, &proto.ProtectedHostConfig{
- HostConfig: &proto.HostConfig{
- Uri: host.URI,
- Protocol: ToResponseProto(host.Proto),
- },
- User: c.Username,
- Password: c.Password,
- })
- }
-
- update := &proto.SyncResponse{
- WiretrusteeConfig: &proto.WiretrusteeConfig{
- Turns: turns,
- },
- }
- log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
- m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update})
- }
- }
- }()
-}
diff --git a/management/server/turncredentials_test.go b/management/server/turncredentials_test.go
deleted file mode 100644
index 667dccbb5..000000000
--- a/management/server/turncredentials_test.go
+++ /dev/null
@@ -1,136 +0,0 @@
-package server
-
-import (
- "context"
- "crypto/hmac"
- "crypto/sha1"
- "encoding/base64"
- "testing"
- "time"
-
- "github.com/netbirdio/netbird/util"
-)
-
-var TurnTestHost = &Host{
- Proto: UDP,
- URI: "turn:turn.wiretrustee.com:77777",
- Username: "username",
- Password: "",
-}
-
-func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
- ttl := util.Duration{Duration: time.Hour}
- secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
-
- tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
- CredentialsTTL: ttl,
- Secret: secret,
- Turns: []*Host{TurnTestHost},
- })
-
- credentials := tested.GenerateCredentials()
-
- if credentials.Username == "" {
- t.Errorf("expected generated TURN username not to be empty, got empty")
- }
- if credentials.Password == "" {
- t.Errorf("expected generated TURN password not to be empty, got empty")
- }
-
- validateMAC(t, credentials.Username, credentials.Password, []byte(secret))
-
-}
-
-func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
- ttl := util.Duration{Duration: 2 * time.Second}
- secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
- peer := "some_peer"
- updateChannel := peersManager.CreateChannel(context.Background(), peer)
-
- tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
- CredentialsTTL: ttl,
- Secret: secret,
- Turns: []*Host{TurnTestHost},
- })
-
- tested.SetupRefresh(context.Background(), peer)
-
- if _, ok := tested.cancelMap[peer]; !ok {
- t.Errorf("expecting peer to be present in a cancel map, got not present")
- }
-
- var updates []*UpdateMessage
-
-loop:
- for timeout := time.After(5 * time.Second); ; {
-
- select {
- case update := <-updateChannel:
- updates = append(updates, update)
- case <-timeout:
- break loop
- }
-
- if len(updates) >= 2 {
- break loop
- }
- }
-
- if len(updates) < 2 {
- t.Errorf("expecting 2 peer credentials updates, got %v", len(updates))
- }
-
- firstUpdate := updates[0].Update.GetWiretrusteeConfig().Turns[0]
- secondUpdate := updates[1].Update.GetWiretrusteeConfig().Turns[0]
-
- if firstUpdate.Password == secondUpdate.Password {
- t.Errorf("expecting first credential update password %v to be diffeerent from second, got equal", firstUpdate.Password)
- }
-
-}
-
-func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
- ttl := util.Duration{Duration: time.Hour}
- secret := "some_secret"
- peersManager := NewPeersUpdateManager(nil)
- peer := "some_peer"
-
- tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{
- CredentialsTTL: ttl,
- Secret: secret,
- Turns: []*Host{TurnTestHost},
- })
-
- tested.SetupRefresh(context.Background(), peer)
- if _, ok := tested.cancelMap[peer]; !ok {
- t.Errorf("expecting peer to be present in a cancel map, got not present")
- }
-
- tested.CancelRefresh(peer)
- if _, ok := tested.cancelMap[peer]; ok {
- t.Errorf("expecting peer to be not present in a cancel map, got present")
- }
-}
-
-func validateMAC(t *testing.T, username string, actualMAC string, key []byte) {
- t.Helper()
- mac := hmac.New(sha1.New, key)
-
- _, err := mac.Write([]byte(username))
- if err != nil {
- t.Fatal(err)
- }
-
- expectedMAC := mac.Sum(nil)
- decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC)
- if err != nil {
- t.Fatal(err)
- }
- equal := hmac.Equal(decodedMAC, expectedMAC)
-
- if !equal {
- t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC)
- }
-}
diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go
index 68e694893..36168cffe 100644
--- a/management/server/updatechannel.go
+++ b/management/server/updatechannel.go
@@ -84,7 +84,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
- log.WithContext(ctx).Warnf("channel for peer %s is %d full", peerID, len(channel))
+ log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
}
} else {
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
diff --git a/management/server/user.go b/management/server/user.go
index 6c7fdfe3c..7acb0b487 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -90,15 +90,16 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool {
return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero()
}
-func (u *User) updateLastLogin(login time.Time) {
- u.LastLogin = login
-}
-
// HasAdminPower returns true if the user has admin or owner roles, false otherwise
func (u *User) HasAdminPower() bool {
return u.Role == UserRoleAdmin || u.Role == UserRoleOwner
}
+// IsAdminOrServiceUser checks if the user has admin power or is a service user.
+func (u *User) IsAdminOrServiceUser() bool {
+ return u.HasAdminPower() || u.IsServiceUser
+}
+
// ToUserInfo converts a User object to a UserInfo object.
func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) {
autoGroups := u.AutoGroups
@@ -362,39 +363,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return newUser.ToUserInfo(idpUser, account.Settings)
}
+func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
+ return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id)
+}
+
// GetUser looks up a user by provided authorization claims.
// It will also create an account if didn't exist for this user before.
func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) {
- account, _, err := am.GetAccountFromToken(ctx, claims)
+ accountID, userID, err := am.GetAccountIDFromToken(ctx, claims)
if err != nil {
return nil, fmt.Errorf("failed to get account with token claims %v", err)
}
- unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
- defer unlock()
-
- account, err = am.Store.GetAccount(ctx, account.Id)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
- return nil, fmt.Errorf("failed to get an account from store %v", err)
+ return nil, err
}
- user, ok := account.Users[claims.UserId]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
- }
-
- // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC
+ // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC
// server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event.
newLogin := user.LastDashboardLoginChanged(claims.LastLogin)
- err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin)
+ err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin)
if err != nil {
log.WithContext(ctx).Errorf("failed saving user last login: %v", err)
}
if newLogin {
meta := map[string]any{"timestamp": claims.LastLogin}
- am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta)
+ am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta)
}
return user, nil
@@ -654,63 +651,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
// GetPAT returns a specific PAT from a user
func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
- return nil, status.Errorf(status.NotFound, "account not found: %s", err)
+ return nil, err
}
- targetUser, ok := account.Users[targetUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
+ targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
+ if err != nil {
+ return nil, err
}
- executingUser, ok := account.Users[initiatorUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
+ if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
+ return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
- if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
- return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser")
+ for _, pat := range targetUser.PATsG {
+ if pat.ID == tokenID {
+ return pat.Copy(), nil
+ }
}
- pat := targetUser.PATs[tokenID]
- if pat == nil {
- return nil, status.Errorf(status.NotFound, "PAT not found")
- }
-
- return pat, nil
+ return nil, status.Errorf(status.NotFound, "PAT not found")
}
// GetAllPATs returns all PATs for a user
func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
+ initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
- return nil, status.Errorf(status.NotFound, "account not found: %s", err)
+ return nil, err
}
- targetUser, ok := account.Users[targetUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
+ targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
+ if err != nil {
+ return nil, err
}
- executingUser, ok := account.Users[initiatorUserID]
- if !ok {
- return nil, status.Errorf(status.NotFound, "user not found")
- }
-
- if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
+ if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
- var pats []*PersonalAccessToken
- for _, pat := range targetUser.PATs {
- pats = append(pats, pat)
+ pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG))
+ for _, pat := range targetUser.PATsG {
+ pats = append(pats, pat.Copy())
}
return pats, nil
diff --git a/management/server/user_test.go b/management/server/user_test.go
index f5a26ef89..7740b4059 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -202,7 +202,8 @@ func TestUser_GetPAT(t *testing.T) {
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
- Id: mockUserID,
+ Id: mockUserID,
+ AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
@@ -234,7 +235,8 @@ func TestUser_GetAllPATs(t *testing.T) {
defer store.Close(context.Background())
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
account.Users[mockUserID] = &User{
- Id: mockUserID,
+ Id: mockUserID,
+ AccountID: mockAccountID,
PATs: map[string]*PersonalAccessToken{
mockTokenID1: {
ID: mockTokenID1,
@@ -799,7 +801,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
assert.NoError(t, err)
}
- acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
+ accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "")
+ assert.NoError(t, err)
+
+ acc, err := am.Store.GetAccount(context.Background(), accID)
assert.NoError(t, err)
for _, id := range tc.expectedDeleted {
diff --git a/relay/Dockerfile b/relay/Dockerfile
new file mode 100644
index 000000000..f750027c3
--- /dev/null
+++ b/relay/Dockerfile
@@ -0,0 +1,4 @@
+FROM gcr.io/distroless/base:debug
+ENTRYPOINT [ "/go/bin/netbird-relay" ]
+ENV NB_LOG_FILE=console
+COPY netbird-relay /go/bin/netbird-relay
diff --git a/relay/auth/allow/allow_all.go b/relay/auth/allow/allow_all.go
new file mode 100644
index 000000000..2d30c59c9
--- /dev/null
+++ b/relay/auth/allow/allow_all.go
@@ -0,0 +1,14 @@
+package allow
+
+// Auth is a Validator that allows all connections.
+// Use this for testing purposes only.
+type Auth struct {
+}
+
+func (a *Auth) Validate(any) error {
+ return nil
+}
+
+func (a *Auth) ValidateHelloMsgType(any) error {
+ return nil
+}
diff --git a/relay/auth/doc.go b/relay/auth/doc.go
new file mode 100644
index 000000000..b3e8dbb08
--- /dev/null
+++ b/relay/auth/doc.go
@@ -0,0 +1,26 @@
+/*
+Package auth manages the authentication process with the relay server.
+
+Key Components:
+
+Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a
+Validator.
+
+Methods:
+
+Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication.
+
+Usage:
+
+To create a new AllowAllAuth validator, simply instantiate it:
+
+ validator := &allow.Auth{}
+
+To validate the authentication, use the Validate method:
+
+ err := validator.Validate(sha256.New, any)
+
+This package provides a simple and effective way to manage authentication with the relay server, ensuring that the
+peers are authenticated properly.
+*/
+package auth
diff --git a/relay/auth/hmac/doc.go b/relay/auth/hmac/doc.go
new file mode 100644
index 000000000..a1b135aa6
--- /dev/null
+++ b/relay/auth/hmac/doc.go
@@ -0,0 +1,8 @@
+/*
+This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the
+tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store
+that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello
+message.
+*/
+
+package hmac
diff --git a/relay/auth/hmac/store.go b/relay/auth/hmac/store.go
new file mode 100644
index 000000000..169b8d6b0
--- /dev/null
+++ b/relay/auth/hmac/store.go
@@ -0,0 +1,44 @@
+package hmac
+
+import (
+ "encoding/base64"
+ "fmt"
+ "sync"
+
+ v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
+)
+
+// TokenStore is a simple in-memory store for a token.
+// It allows updating the token in a thread-safe way.
+type TokenStore struct {
+ mu sync.Mutex
+ token []byte
+}
+
+func (a *TokenStore) UpdateToken(token *Token) error {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if token == nil {
+ return nil
+ }
+
+ sig, err := base64.StdEncoding.DecodeString(token.Signature)
+ if err != nil {
+ return fmt.Errorf("decode signature: %w", err)
+ }
+
+ tok := v2.Token{
+ AuthAlgo: v2.AuthAlgoHMACSHA256,
+ Signature: sig,
+ Payload: []byte(token.Payload),
+ }
+
+ a.token = tok.Marshal()
+ return nil
+}
+
+func (a *TokenStore) TokenBinary() []byte {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ return a.token
+}
diff --git a/relay/auth/hmac/token.go b/relay/auth/hmac/token.go
new file mode 100644
index 000000000..581b1d6fd
--- /dev/null
+++ b/relay/auth/hmac/token.go
@@ -0,0 +1,94 @@
+package hmac
+
+import (
+ "bytes"
+ "crypto/hmac"
+ "encoding/base64"
+ "encoding/gob"
+ "fmt"
+ "hash"
+ "strconv"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type Token struct {
+ Payload string
+ Signature string
+}
+
+func unmarshalToken(payload []byte) (Token, error) {
+ var creds Token
+ buffer := bytes.NewBuffer(payload)
+ decoder := gob.NewDecoder(buffer)
+ err := decoder.Decode(&creds)
+ return creds, err
+}
+
+// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server
+type TimedHMAC struct {
+ secret string
+ timeToLive time.Duration
+}
+
+// NewTimedHMAC creates a new TimedHMAC instance
+func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC {
+ return &TimedHMAC{
+ secret: secret,
+ timeToLive: timeToLive,
+ }
+}
+
+// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC
+// hash of a timestamp with a preshared TURN secret
+func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) {
+ timeAuth := time.Now().Add(m.timeToLive).Unix()
+ timeStamp := strconv.FormatInt(timeAuth, 10)
+
+ checksum, err := m.generate(algo, timeStamp)
+ if err != nil {
+ return nil, err
+ }
+
+ return &Token{
+ Payload: timeStamp,
+ Signature: base64.StdEncoding.EncodeToString(checksum),
+ }, nil
+}
+
+// Validate checks if the token is valid
+func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error {
+ expectedMAC, err := m.generate(algo, token.Payload)
+ if err != nil {
+ return err
+ }
+
+ expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC)
+
+ if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) {
+ return fmt.Errorf("signature mismatch")
+ }
+
+ timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid payload: %w", err)
+ }
+
+ if time.Now().Unix() > timeAuthInt {
+ return fmt.Errorf("expired token")
+ }
+
+ return nil
+}
+
+func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) {
+ mac := hmac.New(algo, []byte(m.secret))
+ _, err := mac.Write([]byte(payload))
+ if err != nil {
+ log.Debugf("failed to generate token: %s", err)
+ return nil, fmt.Errorf("failed to generate token: %w", err)
+ }
+
+ return mac.Sum(nil), nil
+}
diff --git a/relay/auth/hmac/token_test.go b/relay/auth/hmac/token_test.go
new file mode 100644
index 000000000..e629eab97
--- /dev/null
+++ b/relay/auth/hmac/token_test.go
@@ -0,0 +1,105 @@
+package hmac
+
+import (
+ "crypto/sha1"
+ "crypto/sha256"
+ "encoding/base64"
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestGenerateCredentials(t *testing.T) {
+ secret := "secret"
+ timeToLive := 1 * time.Hour
+ v := NewTimedHMAC(secret, timeToLive)
+
+ creds, err := v.GenerateToken(sha1.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if creds.Payload == "" {
+ t.Fatalf("expected non-empty payload")
+ }
+
+ _, err = strconv.ParseInt(creds.Payload, 10, 64)
+ if err != nil {
+ t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
+ }
+
+ _, err = base64.StdEncoding.DecodeString(creds.Signature)
+ if err != nil {
+ t.Fatalf("expected signature to be base64 encoded, got %v", err)
+ }
+}
+
+func TestValidateCredentials(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ manager := NewTimedHMAC(secret, timeToLive)
+
+ // Test valid token
+ creds, err := manager.GenerateToken(sha1.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if err := manager.Validate(sha1.New, *creds); err != nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestInvalidSignature(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ manager := NewTimedHMAC(secret, timeToLive)
+
+ creds, err := manager.GenerateToken(sha256.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ invalidCreds := &Token{
+ Payload: creds.Payload,
+ Signature: "invalidsignature",
+ }
+
+ if err = manager.Validate(sha1.New, *invalidCreds); err == nil {
+ t.Fatalf("expected invalid token due to signature mismatch")
+ }
+}
+
+func TestExpired(t *testing.T) {
+ secret := "supersecret"
+ v := NewTimedHMAC(secret, -1*time.Hour)
+ expiredCreds, err := v.GenerateToken(sha256.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if err = v.Validate(sha1.New, *expiredCreds); err == nil {
+ t.Fatalf("expected invalid token due to expiration")
+ }
+}
+
+func TestInvalidPayload(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ v := NewTimedHMAC(secret, timeToLive)
+
+ creds, err := v.GenerateToken(sha256.New)
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ // Test invalid payload
+ invalidPayloadCreds := &Token{
+ Payload: "invalidtimestamp",
+ Signature: creds.Signature,
+ }
+
+ if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil {
+ t.Fatalf("expected invalid token due to invalid payload")
+ }
+}
diff --git a/relay/auth/hmac/v2/algo.go b/relay/auth/hmac/v2/algo.go
new file mode 100644
index 000000000..c379c2bd7
--- /dev/null
+++ b/relay/auth/hmac/v2/algo.go
@@ -0,0 +1,40 @@
+package v2
+
+import (
+ "crypto/sha256"
+ "hash"
+)
+
+const (
+ AuthAlgoUnknown AuthAlgo = iota
+ AuthAlgoHMACSHA256
+)
+
+type AuthAlgo uint8
+
+func (a AuthAlgo) String() string {
+ switch a {
+ case AuthAlgoHMACSHA256:
+ return "HMAC-SHA256"
+ default:
+ return "Unknown"
+ }
+}
+
+func (a AuthAlgo) New() func() hash.Hash {
+ switch a {
+ case AuthAlgoHMACSHA256:
+ return sha256.New
+ default:
+ return nil
+ }
+}
+
+func (a AuthAlgo) Size() int {
+ switch a {
+ case AuthAlgoHMACSHA256:
+ return sha256.Size
+ default:
+ return 0
+ }
+}
diff --git a/relay/auth/hmac/v2/generator.go b/relay/auth/hmac/v2/generator.go
new file mode 100644
index 000000000..827532730
--- /dev/null
+++ b/relay/auth/hmac/v2/generator.go
@@ -0,0 +1,45 @@
+package v2
+
+import (
+ "crypto/hmac"
+ "fmt"
+ "hash"
+ "strconv"
+ "time"
+)
+
+type Generator struct {
+ algo func() hash.Hash
+ algoType AuthAlgo
+ secret []byte
+ timeToLive time.Duration
+}
+
+func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) {
+ algoFunc := algo.New()
+ if algoFunc == nil {
+ return nil, fmt.Errorf("unsupported auth algorithm: %s", algo)
+ }
+ return &Generator{
+ algo: algoFunc,
+ algoType: algo,
+ secret: secret,
+ timeToLive: timeToLive,
+ }, nil
+}
+
+func (g *Generator) GenerateToken() (*Token, error) {
+ expirationTime := time.Now().Add(g.timeToLive).Unix()
+
+ payload := []byte(strconv.FormatInt(expirationTime, 10))
+
+ h := hmac.New(g.algo, g.secret)
+ h.Write(payload)
+ signature := h.Sum(nil)
+
+ return &Token{
+ AuthAlgo: g.algoType,
+ Signature: signature,
+ Payload: payload,
+ }, nil
+}
diff --git a/relay/auth/hmac/v2/hmac_test.go b/relay/auth/hmac/v2/hmac_test.go
new file mode 100644
index 000000000..40336363f
--- /dev/null
+++ b/relay/auth/hmac/v2/hmac_test.go
@@ -0,0 +1,110 @@
+package v2
+
+import (
+ "strconv"
+ "testing"
+ "time"
+)
+
+func TestGenerateCredentials(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ if len(token.Payload) == 0 {
+ t.Fatalf("expected non-empty payload")
+ }
+
+ _, err = strconv.ParseInt(string(token.Payload), 10, 64)
+ if err != nil {
+ t.Fatalf("expected payload to be a valid unix timestamp, got %v", err)
+ }
+}
+
+func TestValidateCredentials(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err != nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestInvalidSignature(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err == nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestExpired(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := -1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err == nil {
+ t.Fatalf("expected valid token: %s", err)
+ }
+}
+
+func TestInvalidPayload(t *testing.T) {
+ secret := "supersecret"
+ timeToLive := 1 * time.Hour
+ g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive)
+ if err != nil {
+ t.Fatalf("failed to create generator: %v", err)
+ }
+
+ token, err := g.GenerateToken()
+ if err != nil {
+ t.Fatalf("expected no error, got %v", err)
+ }
+
+ token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
+
+ v := NewValidator([]byte(secret))
+ if err := v.Validate(token.Marshal()); err == nil {
+ t.Fatalf("expected invalid token due to invalid payload")
+ }
+}
diff --git a/relay/auth/hmac/v2/token.go b/relay/auth/hmac/v2/token.go
new file mode 100644
index 000000000..553ac01b9
--- /dev/null
+++ b/relay/auth/hmac/v2/token.go
@@ -0,0 +1,39 @@
+package v2
+
+import "errors"
+
+type Token struct {
+ AuthAlgo AuthAlgo
+ Signature []byte
+ Payload []byte
+}
+
+func (t *Token) Marshal() []byte {
+ size := 1 + len(t.Signature) + len(t.Payload)
+
+ buf := make([]byte, size)
+
+ buf[0] = byte(t.AuthAlgo)
+ copy(buf[1:], t.Signature)
+ copy(buf[1+len(t.Signature):], t.Payload)
+
+ return buf
+}
+
+func UnmarshalToken(data []byte) (*Token, error) {
+ if len(data) == 0 {
+ return nil, errors.New("invalid token data")
+ }
+
+ algo := AuthAlgo(data[0])
+ sigSize := algo.Size()
+ if len(data) < 1+sigSize {
+ return nil, errors.New("invalid token data: insufficient length")
+ }
+
+ return &Token{
+ AuthAlgo: algo,
+ Signature: data[1 : 1+sigSize],
+ Payload: data[1+sigSize:],
+ }, nil
+}
diff --git a/relay/auth/hmac/v2/validator.go b/relay/auth/hmac/v2/validator.go
new file mode 100644
index 000000000..7f448dd5f
--- /dev/null
+++ b/relay/auth/hmac/v2/validator.go
@@ -0,0 +1,59 @@
+package v2
+
+import (
+ "crypto/hmac"
+ "errors"
+ "fmt"
+ "strconv"
+ "time"
+)
+
+const minLengthUnixTimestamp = 10
+
+type Validator struct {
+ secret []byte
+}
+
+func NewValidator(secret []byte) *Validator {
+ return &Validator{secret: secret}
+}
+
+func (v *Validator) Validate(data any) error {
+ d, ok := data.([]byte)
+ if !ok {
+ return fmt.Errorf("invalid data type")
+ }
+
+ token, err := UnmarshalToken(d)
+ if err != nil {
+ return fmt.Errorf("unmarshal token: %w", err)
+ }
+
+ if len(token.Payload) < minLengthUnixTimestamp {
+ return errors.New("invalid payload: insufficient length")
+ }
+
+ hashFunc := token.AuthAlgo.New()
+ if hashFunc == nil {
+ return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo)
+ }
+
+ h := hmac.New(hashFunc, v.secret)
+ h.Write(token.Payload)
+ expectedMAC := h.Sum(nil)
+
+ if !hmac.Equal(token.Signature, expectedMAC) {
+ return errors.New("invalid signature")
+ }
+
+ timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64)
+ if err != nil {
+ return fmt.Errorf("invalid payload: %w", err)
+ }
+
+ if time.Now().Unix() > timestamp {
+ return fmt.Errorf("expired token")
+ }
+
+ return nil
+}
diff --git a/relay/auth/hmac/validator.go b/relay/auth/hmac/validator.go
new file mode 100644
index 000000000..b0b7542be
--- /dev/null
+++ b/relay/auth/hmac/validator.go
@@ -0,0 +1,33 @@
+package hmac
+
+import (
+ "crypto/sha256"
+ "fmt"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type TimedHMACValidator struct {
+ *TimedHMAC
+}
+
+func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator {
+ ta := NewTimedHMAC(secret, duration)
+ return &TimedHMACValidator{
+ ta,
+ }
+}
+
+func (a *TimedHMACValidator) Validate(credentials any) error {
+ b, ok := credentials.([]byte)
+ if !ok {
+ return fmt.Errorf("invalid credentials type")
+ }
+ c, err := unmarshalToken(b)
+ if err != nil {
+ log.Debugf("failed to unmarshal token: %s", err)
+ return err
+ }
+ return a.TimedHMAC.Validate(sha256.New, c)
+}
diff --git a/relay/auth/validator.go b/relay/auth/validator.go
new file mode 100644
index 000000000..854efd5bb
--- /dev/null
+++ b/relay/auth/validator.go
@@ -0,0 +1,35 @@
+package auth
+
+import (
+ "time"
+
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2"
+)
+
+// Validator is an interface that defines the Validate method.
+type Validator interface {
+ Validate(any) error
+ // Deprecated: Use Validate instead.
+ ValidateHelloMsgType(any) error
+}
+
+type TimedHMACValidator struct {
+ authenticatorV2 *authv2.Validator
+ authenticator *auth.TimedHMACValidator
+}
+
+func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator {
+ return &TimedHMACValidator{
+ authenticatorV2: authv2.NewValidator(secret),
+ authenticator: auth.NewTimedHMACValidator(string(secret), duration),
+ }
+}
+
+func (a *TimedHMACValidator) Validate(credentials any) error {
+ return a.authenticatorV2.Validate(credentials)
+}
+
+func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error {
+ return a.authenticator.Validate(credentials)
+}
diff --git a/relay/client/addr.go b/relay/client/addr.go
new file mode 100644
index 000000000..af4f459f8
--- /dev/null
+++ b/relay/client/addr.go
@@ -0,0 +1,13 @@
+package client
+
+type RelayAddr struct {
+ addr string
+}
+
+func (a RelayAddr) Network() string {
+ return "relay"
+}
+
+func (a RelayAddr) String() string {
+ return a.addr
+}
diff --git a/relay/client/client.go b/relay/client/client.go
new file mode 100644
index 000000000..90bc3ac41
--- /dev/null
+++ b/relay/client/client.go
@@ -0,0 +1,574 @@
+package client
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/client/dialer/ws"
+ "github.com/netbirdio/netbird/relay/healthcheck"
+ "github.com/netbirdio/netbird/relay/messages"
+)
+
+const (
+ bufferSize = 8820
+ serverResponseTimeout = 8 * time.Second
+)
+
+var (
+ ErrConnAlreadyExists = fmt.Errorf("connection already exists")
+)
+
+type internalStopFlag struct {
+ sync.Mutex
+ stop bool
+}
+
+func newInternalStopFlag() *internalStopFlag {
+ return &internalStopFlag{}
+}
+
+func (isf *internalStopFlag) set() {
+ isf.Lock()
+ defer isf.Unlock()
+ isf.stop = true
+}
+
+func (isf *internalStopFlag) isSet() bool {
+ isf.Lock()
+ defer isf.Unlock()
+ return isf.stop
+}
+
+// Msg carries the payload from the server to the client. With this struct, the net.Conn can free the buffer.
+type Msg struct {
+ Payload []byte
+
+ bufPool *sync.Pool
+ bufPtr *[]byte
+}
+
+func (m *Msg) Free() {
+ m.bufPool.Put(m.bufPtr)
+}
+
+// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the
+// server and forwarding them to the upper layer content reader.
+type connContainer struct {
+ log *log.Entry
+ conn *Conn
+ messages chan Msg
+ msgChanLock sync.Mutex
+ closed bool // flag to check if channel is closed
+ ctx context.Context
+ cancel context.CancelFunc
+}
+
+func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer {
+ ctx, cancel := context.WithCancel(context.Background())
+ return &connContainer{
+ log: log,
+ conn: conn,
+ messages: messages,
+ ctx: ctx,
+ cancel: cancel,
+ }
+}
+
+func (cc *connContainer) writeMsg(msg Msg) {
+ cc.msgChanLock.Lock()
+ defer cc.msgChanLock.Unlock()
+
+ if cc.closed {
+ msg.Free()
+ return
+ }
+
+ select {
+ case cc.messages <- msg:
+ case <-cc.ctx.Done():
+ msg.Free()
+ default:
+ msg.Free()
+ cc.log.Infof("message queue is full")
+ // todo consider to close the connection
+ }
+}
+
+func (cc *connContainer) close() {
+ cc.cancel()
+
+ cc.msgChanLock.Lock()
+ defer cc.msgChanLock.Unlock()
+
+ if cc.closed {
+ return
+ }
+
+ cc.closed = true
+ close(cc.messages)
+
+ for msg := range cc.messages {
+ msg.Free()
+ }
+}
+
+// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and
+// managing connections to other peers. All exported functions are safe to call concurrently. After closing the
+// connection, the client can be reused by calling Connect again. When the client is closed, all connections are closed too.
+// While Connect is in progress, the OpenConn function will block until the connection to the relay server is established.
+type Client struct {
+ log *log.Entry
+ parentCtx context.Context
+ connectionURL string
+ authTokenStore *auth.TokenStore
+ hashedID []byte
+
+ bufPool *sync.Pool
+
+ relayConn net.Conn
+ conns map[string]*connContainer
+ serviceIsRunning bool
+ mu sync.Mutex // protect serviceIsRunning and conns
+ readLoopMutex sync.Mutex
+ wgReadLoop sync.WaitGroup
+ instanceURL *RelayAddr
+ muInstanceURL sync.Mutex
+
+ onDisconnectListener func()
+ listenerMutex sync.Mutex
+}
+
+// NewClient creates a new client for the relay server. The client is not connected to the server until Connect is called.
+func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client {
+ hashedID, hashedStringId := messages.HashID(peerID)
+ c := &Client{
+ log: log.WithFields(log.Fields{"relay": serverURL}),
+ parentCtx: ctx,
+ connectionURL: serverURL,
+ authTokenStore: authTokenStore,
+ hashedID: hashedID,
+ bufPool: &sync.Pool{
+ New: func() any {
+ buf := make([]byte, bufferSize)
+ return &buf
+ },
+ },
+ conns: make(map[string]*connContainer),
+ }
+ c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId)
+ return c
+}
+
+// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs.
+func (c *Client) Connect() error {
+ c.log.Infof("connecting to relay server")
+ c.readLoopMutex.Lock()
+ defer c.readLoopMutex.Unlock()
+
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.serviceIsRunning {
+ return nil
+ }
+
+ err := c.connect()
+ if err != nil {
+ return err
+ }
+
+ c.serviceIsRunning = true
+
+ c.wgReadLoop.Add(1)
+ go c.readLoop(c.relayConn)
+
+ c.log.Infof("relay connection established")
+ return nil
+}
+
+// OpenConn creates a new net.Conn for the destination peer ID. If a connection attempt to the relay server is
+// in progress, the function will block until the connection is established or times out. Otherwise,
+// it will return immediately.
+// todo: what should happen if called with the same peerID multiple times?
+func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if !c.serviceIsRunning {
+ return nil, fmt.Errorf("relay connection is not established")
+ }
+
+ hashedID, hashedStringID := messages.HashID(dstPeerID)
+ _, ok := c.conns[hashedStringID]
+ if ok {
+ return nil, ErrConnAlreadyExists
+ }
+
+ c.log.Infof("open connection to peer: %s", hashedStringID)
+ msgChannel := make(chan Msg, 100)
+ conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL)
+
+ c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel)
+ return conn, nil
+}
+
+// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection.
+func (c *Client) ServerInstanceURL() (string, error) {
+ c.muInstanceURL.Lock()
+ defer c.muInstanceURL.Unlock()
+ if c.instanceURL == nil {
+ return "", fmt.Errorf("relay connection is not established")
+ }
+ return c.instanceURL.String(), nil
+}
+
+// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
+func (c *Client) SetOnDisconnectListener(fn func()) {
+ c.listenerMutex.Lock()
+ defer c.listenerMutex.Unlock()
+ c.onDisconnectListener = fn
+}
+
+// HasConns returns true if there are connections.
+func (c *Client) HasConns() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return len(c.conns) > 0
+}
+
+// Close closes the connection to the relay server and all connections to other peers.
+func (c *Client) Close() error {
+ return c.close(true)
+}
+
+func (c *Client) connect() error {
+ conn, err := ws.Dial(c.connectionURL)
+ if err != nil {
+ return err
+ }
+ c.relayConn = conn
+
+ err = c.handShake()
+ if err != nil {
+ cErr := conn.Close()
+ if cErr != nil {
+ c.log.Errorf("failed to close connection: %s", cErr)
+ }
+ return err
+ }
+
+ return nil
+}
+
+func (c *Client) handShake() error {
+ msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary())
+ if err != nil {
+ c.log.Errorf("failed to marshal auth message: %s", err)
+ return err
+ }
+
+ _, err = c.relayConn.Write(msg)
+ if err != nil {
+ c.log.Errorf("failed to send auth message: %s", err)
+ return err
+ }
+ buf := make([]byte, messages.MaxHandshakeRespSize)
+ n, err := c.readWithTimeout(buf)
+ if err != nil {
+ c.log.Errorf("failed to read auth response: %s", err)
+ return err
+ }
+
+ _, err = messages.ValidateVersion(buf[:n])
+ if err != nil {
+ return fmt.Errorf("validate version: %w", err)
+ }
+
+ msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
+ if err != nil {
+ c.log.Errorf("failed to determine message type: %s", err)
+ return err
+ }
+
+ if msgType != messages.MsgTypeAuthResponse {
+ c.log.Errorf("unexpected message type: %s", msgType)
+ return fmt.Errorf("unexpected message type")
+ }
+
+ addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n])
+ if err != nil {
+ return err
+ }
+
+ c.muInstanceURL.Lock()
+ c.instanceURL = &RelayAddr{addr: addr}
+ c.muInstanceURL.Unlock()
+ return nil
+}
+
+func (c *Client) readLoop(relayConn net.Conn) {
+ internallyStoppedFlag := newInternalStopFlag()
+ hc := healthcheck.NewReceiver(c.log)
+ go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag)
+
+ var (
+ errExit error
+ n int
+ )
+ for {
+ bufPtr := c.bufPool.Get().(*[]byte)
+ buf := *bufPtr
+ n, errExit = relayConn.Read(buf)
+ if errExit != nil {
+ c.log.Infof("start to Relay read loop exit")
+ c.mu.Lock()
+ if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
+ c.log.Debugf("failed to read message from relay server: %s", errExit)
+ }
+ c.mu.Unlock()
+ break
+ }
+
+ _, err := messages.ValidateVersion(buf[:n])
+ if err != nil {
+ c.log.Errorf("failed to validate protocol version: %s", err)
+ c.bufPool.Put(bufPtr)
+ continue
+ }
+
+ msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n])
+ if err != nil {
+ c.log.Errorf("failed to determine message type: %s", err)
+ c.bufPool.Put(bufPtr)
+ continue
+ }
+
+ if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) {
+ break
+ }
+ }
+
+ hc.Stop()
+
+ c.muInstanceURL.Lock()
+ c.instanceURL = nil
+ c.muInstanceURL.Unlock()
+
+ c.notifyDisconnected()
+ c.wgReadLoop.Done()
+ _ = c.close(false)
+}
+
+func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) {
+ switch msgType {
+ case messages.MsgTypeHealthCheck:
+ c.handleHealthCheck(hc, internallyStoppedFlag)
+ c.bufPool.Put(bufPtr)
+ case messages.MsgTypeTransport:
+ return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag)
+ case messages.MsgTypeClose:
+ c.log.Debugf("relay connection close by server")
+ c.bufPool.Put(bufPtr)
+ return false
+ }
+
+ return true
+}
+
+func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) {
+ msg := messages.MarshalHealthcheck()
+ _, wErr := c.relayConn.Write(msg)
+ if wErr != nil {
+ if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
+ c.log.Errorf("failed to send heartbeat: %s", wErr)
+ }
+ }
+ hc.Heartbeat()
+}
+
+func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool {
+ peerID, payload, err := messages.UnmarshalTransportMsg(buf)
+ if err != nil {
+ if c.serviceIsRunning && !internallyStoppedFlag.isSet() {
+ c.log.Errorf("failed to parse transport message: %v", err)
+ }
+
+ c.bufPool.Put(bufPtr)
+ return true
+ }
+
+ stringID := messages.HashIDToString(peerID)
+
+ c.mu.Lock()
+ if !c.serviceIsRunning {
+ c.mu.Unlock()
+ c.bufPool.Put(bufPtr)
+ return false
+ }
+ container, ok := c.conns[stringID]
+ c.mu.Unlock()
+ if !ok {
+ c.log.Errorf("peer not found: %s", stringID)
+ c.bufPool.Put(bufPtr)
+ return true
+ }
+ msg := Msg{
+ bufPool: c.bufPool,
+ bufPtr: bufPtr,
+ Payload: payload,
+ }
+ container.writeMsg(msg)
+ return true
+}
+
+func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) {
+ c.mu.Lock()
+ conn, ok := c.conns[id]
+ c.mu.Unlock()
+ if !ok {
+ return 0, io.EOF
+ }
+
+ if conn.conn != connReference {
+ return 0, io.EOF
+ }
+
+ // todo: use buffer pool instead of create new transport msg.
+ msg, err := messages.MarshalTransportMsg(dstID, payload)
+ if err != nil {
+ c.log.Errorf("failed to marshal transport message: %s", err)
+ return 0, err
+ }
+
+ // the write always returns 0 length because the underlying connection does not support size feedback.
+ _, err = c.relayConn.Write(msg)
+ if err != nil {
+ c.log.Errorf("failed to write transport message: %s", err)
+ }
+ return len(payload), err
+}
+
+func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) {
+ for {
+ select {
+ case _, ok := <-hc.OnTimeout:
+ if !ok {
+ return
+ }
+ c.log.Errorf("health check timeout")
+ internalStopFlag.set()
+ if err := conn.Close(); err != nil {
+ // ignore the err handling because the readLoop will handle it
+ c.log.Warnf("failed to close connection: %s", err)
+ }
+ return
+ case <-c.parentCtx.Done():
+ err := c.close(true)
+ if err != nil {
+ c.log.Errorf("failed to teardown connection: %s", err)
+ }
+ return
+ }
+ }
+}
+
+func (c *Client) closeAllConns() {
+ for _, container := range c.conns {
+ container.close()
+ }
+ c.conns = make(map[string]*connContainer)
+}
+
+func (c *Client) closeConn(connReference *Conn, id string) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ container, ok := c.conns[id]
+ if !ok {
+ return fmt.Errorf("connection already closed")
+ }
+
+ if container.conn != connReference {
+ return fmt.Errorf("conn reference mismatch")
+ }
+ c.log.Infof("free up connection to peer: %s", id)
+ delete(c.conns, id)
+ container.close()
+
+ return nil
+}
+
+func (c *Client) close(gracefullyExit bool) error {
+ c.readLoopMutex.Lock()
+ defer c.readLoopMutex.Unlock()
+
+ c.mu.Lock()
+ var err error
+ if !c.serviceIsRunning {
+ c.mu.Unlock()
+ c.log.Warn("relay connection was already marked as not running")
+ return nil
+ }
+
+ c.serviceIsRunning = false
+ c.log.Infof("closing all peer connections")
+ c.closeAllConns()
+ if gracefullyExit {
+ c.writeCloseMsg()
+ }
+ err = c.relayConn.Close()
+ c.mu.Unlock()
+
+ c.log.Infof("waiting for read loop to close")
+ c.wgReadLoop.Wait()
+ c.log.Infof("relay connection closed")
+ return err
+}
+
+func (c *Client) notifyDisconnected() {
+ c.listenerMutex.Lock()
+ defer c.listenerMutex.Unlock()
+
+ if c.onDisconnectListener == nil {
+ return
+ }
+ go c.onDisconnectListener()
+}
+
+func (c *Client) writeCloseMsg() {
+ msg := messages.MarshalCloseMsg()
+ _, err := c.relayConn.Write(msg)
+ if err != nil {
+ c.log.Errorf("failed to send close message: %s", err)
+ }
+}
+
+func (c *Client) readWithTimeout(buf []byte) (int, error) {
+ ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout)
+ defer cancel()
+
+ readDone := make(chan struct{})
+ var (
+ n int
+ err error
+ )
+
+ go func() {
+ n, err = c.relayConn.Read(buf)
+ close(readDone)
+ }()
+
+ select {
+ case <-ctx.Done():
+ return 0, fmt.Errorf("read operation timed out")
+ case <-readDone:
+ return n, err
+ }
+}
diff --git a/relay/client/client_test.go b/relay/client/client_test.go
new file mode 100644
index 000000000..ef28203e9
--- /dev/null
+++ b/relay/client/client_test.go
@@ -0,0 +1,712 @@
+package client
+
+import (
+ "context"
+ "net"
+ "os"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/auth/allow"
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/util"
+
+ "github.com/netbirdio/netbird/relay/server"
+)
+
+var (
+ av = &allow.Auth{}
+ hmacTokenStore = &hmac.TokenStore{}
+ serverListenAddr = "127.0.0.1:1234"
+ serverURL = "rel://127.0.0.1:1234"
+)
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("error", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
+func TestClient(t *testing.T) {
+ ctx := context.Background()
+
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ listenCfg := server.ListenerConfig{Address: serverListenAddr}
+ err := srv.Listen(listenCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for server to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+ t.Log("alice connecting to server")
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer clientAlice.Close()
+
+ t.Log("placeholder connecting to server")
+ clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder")
+ err = clientPlaceHolder.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer clientPlaceHolder.Close()
+
+ t.Log("Bob connecting to server")
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
+ err = clientBob.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer clientBob.Close()
+
+ t.Log("Alice open connection to Bob")
+ connAliceToBob, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ t.Log("Bob open connection to Alice")
+ connBobToAlice, err := clientBob.OpenConn("alice")
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+ log.Debugf("alice sent message to bob")
+
+ buf := make([]byte, 65535)
+ n, err := connBobToAlice.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+ log.Debugf("on new message from alice to bob")
+
+ if payload != string(buf[:n]) {
+ t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
+ }
+}
+
+func TestRegistration(t *testing.T) {
+ ctx := context.Background()
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ // wait for server to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ _ = srv.Shutdown(ctx)
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close conn: %s", err)
+ }
+ err = srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+}
+
+func TestRegistrationTimeout(t *testing.T) {
+	ctx := context.Background()
+	// bind fake listeners that never speak the relay protocol, so the
+	// client's registration attempt must time out
+	fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{
+		Port: 1234,
+		IP:   net.ParseIP("0.0.0.0"),
+	})
+	if err != nil {
+		t.Fatalf("failed to bind UDP server: %s", err)
+	}
+	defer func(fakeUDPListener *net.UDPConn) {
+		_ = fakeUDPListener.Close()
+	}(fakeUDPListener)
+
+	fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{
+		Port: 1234,
+		IP:   net.ParseIP("0.0.0.0"),
+	})
+	if err != nil {
+		t.Fatalf("failed to bind TCP server: %s", err)
+	}
+	defer func(fakeTCPListener *net.TCPListener) {
+		_ = fakeTCPListener.Close()
+	}(fakeTCPListener)
+
+	clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice")
+	err = clientAlice.Connect()
+	if err == nil {
+		// the previous message read "failed to connect to server", which is
+		// the opposite of what this branch detects
+		t.Errorf("expected registration timeout error, got nil")
+	}
+	log.Debugf("%s", err)
+	err = clientAlice.Close()
+	if err != nil {
+		t.Errorf("failed to close conn: %s", err)
+	}
+}
+
+func TestEcho(t *testing.T) {
+ ctx := context.Background()
+ idAlice := "alice"
+ idBob := "bob"
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close Alice client: %s", err)
+ }
+ }()
+
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
+ err = clientBob.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientBob.Close()
+ if err != nil {
+ t.Errorf("failed to close Bob client: %s", err)
+ }
+ }()
+
+ connAliceToBob, err := clientAlice.OpenConn(idBob)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ connBobToAlice, err := clientBob.OpenConn(idAlice)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ buf := make([]byte, 65535)
+ n, err := connBobToAlice.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ _, err = connBobToAlice.Write(buf[:n])
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ n, err = connAliceToBob.Read(buf)
+ if err != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ if payload != string(buf[:n]) {
+ t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
+ }
+}
+
+func TestBindToUnavailabePeer(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ log.Infof("closing server")
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+ _, err = clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ log.Infof("closing client")
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+}
+
+func TestBindReconnect(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ log.Infof("closing server")
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ _, err = clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob")
+ err = clientBob.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ chBob, err := clientBob.OpenConn("alice")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ log.Infof("closing client Alice")
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+
+ clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ chAlice, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ testString := "hello alice, I am bob"
+ _, err = chBob.Write([]byte(testString))
+ if err != nil {
+ t.Errorf("failed to write to channel: %s", err)
+ }
+
+ buf := make([]byte, 65535)
+ n, err := chAlice.Read(buf)
+ if err != nil {
+ t.Errorf("failed to read from channel: %s", err)
+ }
+
+ if testString != string(buf[:n]) {
+ t.Errorf("expected %s, got %s", testString, string(buf[:n]))
+ }
+
+ log.Infof("closing client")
+ err = clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close client: %s", err)
+ }
+}
+
+func TestCloseConn(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ log.Infof("closing server")
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Errorf("failed to connect to server: %s", err)
+ }
+
+ conn, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ log.Infof("closing connection")
+ err = conn.Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+
+ _, err = conn.Read(make([]byte, 1))
+ if err == nil {
+ t.Errorf("unexpected reading from closed connection")
+ }
+
+ _, err = conn.Write([]byte("hello"))
+ if err == nil {
+ t.Errorf("unexpected writing from closed connection")
+ }
+}
+
+func TestCloseRelayConn(t *testing.T) {
+ ctx := context.Background()
+
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ log.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice")
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+
+ conn, err := clientAlice.OpenConn("bob")
+ if err != nil {
+ t.Errorf("failed to bind channel: %s", err)
+ }
+
+ _ = clientAlice.relayConn.Close()
+
+ _, err = conn.Read(make([]byte, 1))
+ if err == nil {
+ t.Errorf("unexpected reading from closed connection")
+ }
+
+ _, err = clientAlice.OpenConn("bob")
+ if err == nil {
+ t.Errorf("unexpected opening connection to closed server")
+ }
+}
+
+func TestCloseByServer(t *testing.T) {
+	ctx := context.Background()
+
+	srvCfg := server.ListenerConfig{Address: serverListenAddr}
+	srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+
+	go func() {
+		err := srv1.Listen(srvCfg)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	// wait for servers to start
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	idAlice := "alice"
+	log.Debugf("connect by alice")
+	relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+	err = relayClient.Connect()
+	if err != nil {
+		// t.Fatalf, not log.Fatalf: log.Fatalf calls os.Exit and would bypass
+		// the testing framework and any deferred cleanup
+		t.Fatalf("failed to connect to server: %s", err)
+	}
+
+	disconnected := make(chan struct{})
+	relayClient.SetOnDisconnectListener(func() {
+		log.Infof("client disconnected")
+		close(disconnected)
+	})
+
+	err = srv1.Shutdown(ctx)
+	if err != nil {
+		t.Fatalf("failed to close server: %s", err)
+	}
+
+	select {
+	case <-disconnected:
+	case <-time.After(3 * time.Second):
+		t.Fatalf("timeout waiting for client to disconnect")
+	}
+
+	_, err = relayClient.OpenConn("bob")
+	if err == nil {
+		t.Errorf("unexpected opening connection to closed server")
+	}
+}
+
+func TestCloseByClient(t *testing.T) {
+	ctx := context.Background()
+
+	srvCfg := server.ListenerConfig{Address: serverListenAddr}
+	srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		err := srv.Listen(srvCfg)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	// wait for servers to start
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	idAlice := "alice"
+	log.Debugf("connect by alice")
+	relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+	err = relayClient.Connect()
+	if err != nil {
+		// t.Fatalf, not log.Fatalf: log.Fatalf calls os.Exit, which would skip
+		// the server shutdown below and abort the whole test binary
+		t.Fatalf("failed to connect to server: %s", err)
+	}
+
+	err = relayClient.Close()
+	if err != nil {
+		t.Errorf("failed to close client: %s", err)
+	}
+
+	_, err = relayClient.OpenConn("bob")
+	if err == nil {
+		t.Errorf("unexpected opening connection to closed server")
+	}
+
+	err = srv.Shutdown(ctx)
+	if err != nil {
+		t.Fatalf("failed to close server: %s", err)
+	}
+}
+
+func TestCloseNotDrainedChannel(t *testing.T) {
+ ctx := context.Background()
+ idAlice := "alice"
+ idBob := "bob"
+ srvCfg := server.ListenerConfig{Address: serverListenAddr}
+ srv, err := server.NewServer(otel.Meter(""), serverURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ err := srv.Listen(srvCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for servers to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice)
+ err = clientAlice.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close Alice client: %s", err)
+ }
+ }()
+
+ clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob)
+ err = clientBob.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ defer func() {
+ err := clientBob.Close()
+ if err != nil {
+ t.Errorf("failed to close Bob client: %s", err)
+ }
+ }()
+
+ connAliceToBob, err := clientAlice.OpenConn(idBob)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ connBobToAlice, err := clientBob.OpenConn(idAlice)
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+
+ payload := "hello bob, I am alice"
+ // the internal channel buffer size is 2. So we should overflow it
+ for i := 0; i < 5; i++ {
+ _, err = connAliceToBob.Write([]byte(payload))
+ if err != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ }
+
+ // wait for delivery
+ time.Sleep(1 * time.Second)
+ err = connBobToAlice.Close()
+ if err != nil {
+ t.Errorf("failed to close channel: %s", err)
+ }
+}
+
+func waitForServerToStart(errChan chan error) error {
+ select {
+ case err := <-errChan:
+ if err != nil {
+ return err
+ }
+ case <-time.After(300 * time.Millisecond):
+ return nil
+ }
+ return nil
+}
diff --git a/relay/client/conn.go b/relay/client/conn.go
new file mode 100644
index 000000000..b4ff903e8
--- /dev/null
+++ b/relay/client/conn.go
@@ -0,0 +1,76 @@
+package client
+
+import (
+ "io"
+ "net"
+ "time"
+)
+
+// Conn represent a connection to a relayed remote peer.
+type Conn struct {
+ client *Client
+ dstID []byte
+ dstStringID string
+ messageChan chan Msg
+ instanceURL *RelayAddr
+}
+
+// NewConn creates a new connection to a relayed remote peer.
+// client: the client instance, it used to send messages to the destination peer
+// dstID: the destination peer ID
+// dstStringID: the destination peer ID in string format
+// messageChan: the channel where the messages will be received
+// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer
+func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn {
+ c := &Conn{
+ client: client,
+ dstID: dstID,
+ dstStringID: dstStringID,
+ messageChan: messageChan,
+ instanceURL: instanceURL,
+ }
+
+ return c
+}
+
+func (c *Conn) Write(p []byte) (n int, err error) {
+ return c.client.writeTo(c, c.dstStringID, c.dstID, p)
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+ msg, ok := <-c.messageChan
+ if !ok {
+ return 0, io.EOF
+ }
+
+ n = copy(b, msg.Payload)
+ msg.Free()
+ return n, nil
+}
+
+func (c *Conn) Close() error {
+ return c.client.closeConn(c, c.dstStringID)
+}
+
+func (c *Conn) LocalAddr() net.Addr {
+ return c.client.relayConn.LocalAddr()
+}
+
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.instanceURL
+}
+
+// SetDeadline is not supported for relayed connections.
+func (c *Conn) SetDeadline(t time.Time) error {
+	//TODO implement me
+	panic("SetDeadline is not implemented")
+}
+
+// SetReadDeadline is not supported for relayed connections.
+func (c *Conn) SetReadDeadline(t time.Time) error {
+	//TODO implement me
+	panic("SetReadDeadline is not implemented")
+}
+
+// SetWriteDeadline is not supported for relayed connections.
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+	//TODO implement me
+	// fixed copy-paste defect: the panic message previously named SetReadDeadline
+	panic("SetWriteDeadline is not implemented")
+}
diff --git a/relay/client/dialer/ws/addr.go b/relay/client/dialer/ws/addr.go
new file mode 100644
index 000000000..43f5dd6af
--- /dev/null
+++ b/relay/client/dialer/ws/addr.go
@@ -0,0 +1,13 @@
+package ws
+
+type WebsocketAddr struct {
+ addr string
+}
+
+func (a WebsocketAddr) Network() string {
+ return "websocket"
+}
+
+func (a WebsocketAddr) String() string {
+ return a.addr
+}
diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go
new file mode 100644
index 000000000..e7f771b8d
--- /dev/null
+++ b/relay/client/dialer/ws/conn.go
@@ -0,0 +1,66 @@
+package ws
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "time"
+
+ "nhooyr.io/websocket"
+)
+
+type Conn struct {
+ ctx context.Context
+ *websocket.Conn
+ remoteAddr WebsocketAddr
+}
+
+func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn {
+ return &Conn{
+ ctx: context.Background(),
+ Conn: wsConn,
+ remoteAddr: WebsocketAddr{serverAddress},
+ }
+}
+
+func (c *Conn) Read(b []byte) (n int, err error) {
+ t, ioReader, err := c.Conn.Reader(c.ctx)
+ if err != nil {
+ return 0, err
+ }
+
+ if t != websocket.MessageBinary {
+ return 0, fmt.Errorf("unexpected message type")
+ }
+
+ return ioReader.Read(b)
+}
+
+// Write sends b to the peer as one binary WebSocket message.
+// On success it reports len(b) bytes written, as required by the io.Writer
+// contract; the previous implementation returned 0 even on success, which
+// breaks callers (e.g. io.Copy) that rely on the reported byte count.
+func (c *Conn) Write(b []byte) (n int, err error) {
+	if err = c.Conn.Write(c.ctx, websocket.MessageBinary, b); err != nil {
+		return 0, err
+	}
+	return len(b), nil
+}
+
+func (c *Conn) RemoteAddr() net.Addr {
+ return c.remoteAddr
+}
+
+func (c *Conn) LocalAddr() net.Addr {
+ return WebsocketAddr{addr: "unknown"}
+}
+
+func (c *Conn) SetReadDeadline(t time.Time) error {
+ return fmt.Errorf("SetReadDeadline is not implemented")
+}
+
+func (c *Conn) SetWriteDeadline(t time.Time) error {
+ return fmt.Errorf("SetWriteDeadline is not implemented")
+}
+
+func (c *Conn) SetDeadline(t time.Time) error {
+ return fmt.Errorf("SetDeadline is not implemented")
+}
+
+func (c *Conn) Close() error {
+ return c.Conn.CloseNow()
+}
diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go
new file mode 100644
index 000000000..d9388aafd
--- /dev/null
+++ b/relay/client/dialer/ws/ws.go
@@ -0,0 +1,67 @@
+package ws
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/url"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "nhooyr.io/websocket"
+
+ "github.com/netbirdio/netbird/relay/server/listener/ws"
+ nbnet "github.com/netbirdio/netbird/util/net"
+)
+
+func Dial(address string) (net.Conn, error) {
+ wsURL, err := prepareURL(address)
+ if err != nil {
+ return nil, err
+ }
+
+ opts := &websocket.DialOptions{
+ HTTPClient: httpClientNbDialer(),
+ }
+
+ parsedURL, err := url.Parse(wsURL)
+ if err != nil {
+ return nil, err
+ }
+ parsedURL.Path = ws.URLPath
+
+ wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts)
+ if err != nil {
+ log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err)
+ return nil, err
+ }
+ if resp.Body != nil {
+ _ = resp.Body.Close()
+ }
+
+ conn := NewConn(wsConn, address)
+ return conn, nil
+}
+
+func prepareURL(address string) (string, error) {
+ if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") {
+ return "", fmt.Errorf("unsupported scheme: %s", address)
+ }
+
+ return strings.Replace(address, "rel", "ws", 1), nil
+}
+
+func httpClientNbDialer() *http.Client {
+ customDialer := nbnet.NewDialer()
+
+ customTransport := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return customDialer.DialContext(ctx, network, addr)
+ },
+ }
+
+ return &http.Client{
+ Transport: customTransport,
+ }
+}
diff --git a/relay/client/doc.go b/relay/client/doc.go
new file mode 100644
index 000000000..1339251d9
--- /dev/null
+++ b/relay/client/doc.go
@@ -0,0 +1,12 @@
+/*
+Package client contains the implementation of the Relay client.
+
+The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages,
+Keep persistent connection with the Relay server and handle the connection issues.
+It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security).
+
+If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to
+the relay server. The connection with these relay servers will be closed if there is no active connection. The peers
+negotiate the common relay instance via signaling service.
+*/
+package client
diff --git a/relay/client/guard.go b/relay/client/guard.go
new file mode 100644
index 000000000..f826cf1b6
--- /dev/null
+++ b/relay/client/guard.go
@@ -0,0 +1,48 @@
+package client
+
+import (
+ "context"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+ reconnectingTimeout = 5 * time.Second
+)
+
+// Guard manage the reconnection tries to the Relay server in case of disconnection event.
+type Guard struct {
+ ctx context.Context
+ relayClient *Client
+}
+
+// NewGuard creates a new guard for the relay client.
+func NewGuard(context context.Context, relayClient *Client) *Guard {
+ g := &Guard{
+ ctx: context,
+ relayClient: relayClient,
+ }
+ return g
+}
+
+// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
+// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
+func (g *Guard) OnDisconnected() {
+ ticker := time.NewTicker(reconnectingTimeout)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ticker.C:
+ err := g.relayClient.Connect()
+ if err != nil {
+ log.Errorf("failed to reconnect to relay server: %s", err)
+ continue
+ }
+ return
+ case <-g.ctx.Done():
+ return
+ }
+ }
+}
diff --git a/relay/client/manager.go b/relay/client/manager.go
new file mode 100644
index 000000000..4554c7c0f
--- /dev/null
+++ b/relay/client/manager.go
@@ -0,0 +1,326 @@
+package client
+
+import (
+ "container/list"
+ "context"
+ "fmt"
+ "net"
+ "reflect"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ relayAuth "github.com/netbirdio/netbird/relay/auth/hmac"
+)
+
+var (
+ relayCleanupInterval = 60 * time.Second
+
+ ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
+)
+
+// RelayTrack hold the relay clients for the foreign relay servers.
+// With the mutex can ensure we can open new connection in case the relay connection has been established with
+// the relay server.
+type RelayTrack struct {
+ sync.RWMutex
+ relayClient *Client
+ err error
+}
+
+func NewRelayTrack() *RelayTrack {
+ return &RelayTrack{}
+}
+
+type OnServerCloseListener func()
+
+// ManagerService is the interface for the relay manager.
+type ManagerService interface {
+ Serve() error
+ OpenConn(serverAddress, peerKey string) (net.Conn, error)
+ AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error
+ RelayInstanceAddress() (string, error)
+ ServerURLs() []string
+ HasRelayAddress() bool
+ UpdateToken(token *relayAuth.Token) error
+}
+
+// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL
+// and automatically reconnect to them in case disconnection.
+// The manager also manage temporary relay connection. If a client wants to communicate with a client on a
+// different relay server, the manager will establish a new connection to the relay server. The connection with these
+// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
+// unused relay connection and close it.
+type Manager struct {
+ ctx context.Context
+ serverURLs []string
+ peerID string
+ tokenStore *relayAuth.TokenStore
+
+ relayClient *Client
+ reconnectGuard *Guard
+
+ relayClients map[string]*RelayTrack
+ relayClientsMutex sync.RWMutex
+
+ onDisconnectedListeners map[string]*list.List
+ listenerLock sync.Mutex
+}
+
+// NewManager creates a new manager instance.
+// The serverURL address can be empty. In this case, the manager will not serve.
+func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
+ return &Manager{
+ ctx: ctx,
+ serverURLs: serverURLs,
+ peerID: peerID,
+ tokenStore: &relayAuth.TokenStore{},
+ relayClients: make(map[string]*RelayTrack),
+ onDisconnectedListeners: make(map[string]*list.List),
+ }
+}
+
+// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
+// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
+func (m *Manager) Serve() error {
+ if m.relayClient != nil {
+ return fmt.Errorf("manager already serving")
+ }
+ log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
+
+ sp := ServerPicker{
+ TokenStore: m.tokenStore,
+ PeerID: m.peerID,
+ }
+
+ client, err := sp.PickServer(m.ctx, m.serverURLs)
+ if err != nil {
+ return err
+ }
+ m.relayClient = client
+
+ m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
+ m.relayClient.SetOnDisconnectListener(func() {
+ m.onServerDisconnected(client.connectionURL)
+ })
+ m.startCleanupLoop()
+ return nil
+}
+
+// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
+// established via the relay server. If the peer is on a different relay server, the manager will establish a new
+// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
+func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
+ if m.relayClient == nil {
+ return nil, ErrRelayClientNotConnected
+ }
+
+ foreign, err := m.isForeignServer(serverAddress)
+ if err != nil {
+ return nil, err
+ }
+
+ var (
+ netConn net.Conn
+ )
+ if !foreign {
+ log.Debugf("open peer connection via permanent server: %s", peerKey)
+ netConn, err = m.relayClient.OpenConn(peerKey)
+ } else {
+ log.Debugf("open peer connection via foreign server: %s", serverAddress)
+ netConn, err = m.openConnVia(serverAddress, peerKey)
+ }
+ if err != nil {
+ return nil, err
+ }
+
+ return netConn, err
+}
+
+// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
+// closed.
+func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
+ foreign, err := m.isForeignServer(serverAddress)
+ if err != nil {
+ return err
+ }
+
+ var listenerAddr string
+ if foreign {
+ listenerAddr = serverAddress
+ } else {
+ listenerAddr = m.relayClient.connectionURL
+ }
+ m.addListener(listenerAddr, onClosedListener)
+ return nil
+}
+
+// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
+// lost. This address will be sent to the target peer to choose the common relay server for the communication.
+func (m *Manager) RelayInstanceAddress() (string, error) {
+ if m.relayClient == nil {
+ return "", ErrRelayClientNotConnected
+ }
+ return m.relayClient.ServerInstanceURL()
+}
+
+// ServerURLs returns the addresses of the relay servers.
+func (m *Manager) ServerURLs() []string {
+ return m.serverURLs
+}
+
+// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
+// Relay service.
+func (m *Manager) HasRelayAddress() bool {
+ return len(m.serverURLs) > 0
+}
+
+// UpdateToken updates the token in the token store.
+func (m *Manager) UpdateToken(token *relayAuth.Token) error {
+ return m.tokenStore.UpdateToken(token)
+}
+
+func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
+ // check if already has a connection to the desired relay server
+ m.relayClientsMutex.RLock()
+ rt, ok := m.relayClients[serverAddress]
+ if ok {
+ rt.RLock()
+ m.relayClientsMutex.RUnlock()
+ defer rt.RUnlock()
+ if rt.err != nil {
+ return nil, rt.err
+ }
+ return rt.relayClient.OpenConn(peerKey)
+ }
+ m.relayClientsMutex.RUnlock()
+
+ // if not, establish a new connection but check it again (because changed the lock type) before starting the
+ // connection
+ m.relayClientsMutex.Lock()
+ rt, ok = m.relayClients[serverAddress]
+ if ok {
+ rt.RLock()
+ m.relayClientsMutex.Unlock()
+ defer rt.RUnlock()
+ if rt.err != nil {
+ return nil, rt.err
+ }
+ return rt.relayClient.OpenConn(peerKey)
+ }
+
+ // create a new relay client and store it in the relayClients map
+ rt = NewRelayTrack()
+ rt.Lock()
+ m.relayClients[serverAddress] = rt
+ m.relayClientsMutex.Unlock()
+
+ relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID)
+ err := relayClient.Connect()
+ if err != nil {
+ rt.err = err
+ rt.Unlock()
+ m.relayClientsMutex.Lock()
+ delete(m.relayClients, serverAddress)
+ m.relayClientsMutex.Unlock()
+ return nil, err
+ }
+ // if connection closed then delete the relay client from the list
+ relayClient.SetOnDisconnectListener(func() {
+ m.onServerDisconnected(serverAddress)
+ })
+ rt.relayClient = relayClient
+ rt.Unlock()
+
+ conn, err := relayClient.OpenConn(peerKey)
+ if err != nil {
+ return nil, err
+ }
+ return conn, nil
+}
+
+func (m *Manager) onServerDisconnected(serverAddress string) {
+ if serverAddress == m.relayClient.connectionURL {
+ go m.reconnectGuard.OnDisconnected()
+ }
+
+ m.notifyOnDisconnectListeners(serverAddress)
+}
+
+func (m *Manager) isForeignServer(address string) (bool, error) {
+ rAddr, err := m.relayClient.ServerInstanceURL()
+ if err != nil {
+ return false, fmt.Errorf("relay client not connected")
+ }
+ return rAddr != address, nil
+}
+
+func (m *Manager) startCleanupLoop() {
+ if m.ctx.Err() != nil {
+ return
+ }
+
+ ticker := time.NewTicker(relayCleanupInterval)
+ go func() {
+ defer ticker.Stop()
+ for {
+ select {
+ case <-m.ctx.Done():
+ return
+ case <-ticker.C:
+ m.cleanUpUnusedRelays()
+ }
+ }
+ }()
+}
+
+func (m *Manager) cleanUpUnusedRelays() {
+ m.relayClientsMutex.Lock()
+ defer m.relayClientsMutex.Unlock()
+
+ for addr, rt := range m.relayClients {
+ rt.Lock()
+ if rt.relayClient.HasConns() {
+ rt.Unlock()
+ continue
+ }
+ rt.relayClient.SetOnDisconnectListener(nil)
+ go func() {
+ _ = rt.relayClient.Close()
+ }()
+ log.Debugf("clean up unused relay server connection: %s", addr)
+ delete(m.relayClients, addr)
+ rt.Unlock()
+ }
+}
+
+func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) {
+ m.listenerLock.Lock()
+ defer m.listenerLock.Unlock()
+ l, ok := m.onDisconnectedListeners[serverAddress]
+ if !ok {
+ l = list.New()
+ }
+ for e := l.Front(); e != nil; e = e.Next() {
+ if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() {
+ return
+ }
+ }
+ l.PushBack(onClosedListener)
+ m.onDisconnectedListeners[serverAddress] = l
+}
+
+func (m *Manager) notifyOnDisconnectListeners(serverAddress string) {
+ m.listenerLock.Lock()
+ defer m.listenerLock.Unlock()
+
+ l, ok := m.onDisconnectedListeners[serverAddress]
+ if !ok {
+ return
+ }
+ for e := l.Front(); e != nil; e = e.Next() {
+ go e.Value.(OnServerCloseListener)()
+ }
+ delete(m.onDisconnectedListeners, serverAddress)
+}
diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go
new file mode 100644
index 000000000..e9cc2c581
--- /dev/null
+++ b/relay/client/manager_test.go
@@ -0,0 +1,432 @@
+package client
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/server"
+)
+
+// TestEmptyURL verifies that Serve fails when the Manager is constructed with
+// no relay server URLs.
+func TestEmptyURL(t *testing.T) {
+	mgr := NewManager(context.Background(), nil, "alice")
+	err := mgr.Serve()
+	if err == nil {
+		t.Errorf("expected error, got nil")
+	}
+}
+
+// TestForeignConn starts two relay servers, connects Alice to server 1 and Bob
+// to server 2, then opens connections from both sides against Bob's (foreign,
+// from Alice's point of view) relay and verifies a payload echoes back intact.
+func TestForeignConn(t *testing.T) {
+	ctx := context.Background()
+
+	srvCfg1 := server.ListenerConfig{
+		Address: "localhost:1234",
+	}
+	srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		err := srv1.Listen(srvCfg1)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	defer func() {
+		err := srv1.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	srvCfg2 := server.ListenerConfig{
+		Address: "localhost:2234",
+	}
+	srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan2 := make(chan error, 1)
+	go func() {
+		err := srv2.Listen(srvCfg2)
+		if err != nil {
+			errChan2 <- err
+		}
+	}()
+
+	defer func() {
+		err := srv2.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	if err := waitForServerToStart(errChan2); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	idAlice := "alice"
+	log.Debugf("connect by alice")
+	mCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
+	err = clientAlice.Serve()
+	if err != nil {
+		t.Fatalf("failed to serve manager: %s", err)
+	}
+
+	idBob := "bob"
+	log.Debugf("connect by bob")
+	clientBob := NewManager(mCtx, toURL(srvCfg2), idBob)
+	err = clientBob.Serve()
+	if err != nil {
+		t.Fatalf("failed to serve manager: %s", err)
+	}
+	bobsSrvAddr, err := clientBob.RelayInstanceAddress()
+	if err != nil {
+		t.Fatalf("failed to get relay address: %s", err)
+	}
+	// Alice opens a connection on Bob's relay (a foreign server for Alice).
+	connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob)
+	if err != nil {
+		t.Fatalf("failed to bind channel: %s", err)
+	}
+	connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice)
+	if err != nil {
+		t.Fatalf("failed to bind channel: %s", err)
+	}
+
+	// Round-trip: Alice -> Bob, Bob echoes back, Alice verifies the payload.
+	payload := "hello bob, I am alice"
+	_, err = connAliceToBob.Write([]byte(payload))
+	if err != nil {
+		t.Fatalf("failed to write to channel: %s", err)
+	}
+
+	buf := make([]byte, 65535)
+	n, err := connBobToAlice.Read(buf)
+	if err != nil {
+		t.Fatalf("failed to read from channel: %s", err)
+	}
+
+	_, err = connBobToAlice.Write(buf[:n])
+	if err != nil {
+		t.Fatalf("failed to write to channel: %s", err)
+	}
+
+	n, err = connAliceToBob.Read(buf)
+	if err != nil {
+		t.Fatalf("failed to read from channel: %s", err)
+	}
+
+	if payload != string(buf[:n]) {
+		t.Fatalf("expected %s, got %s", payload, string(buf[:n]))
+	}
+}
+
+// TestForeginConnClose verifies that a connection opened against a foreign
+// relay server can be closed without error.
+// NOTE(review): "Foregin" is a typo for "Foreign" in the test name.
+func TestForeginConnClose(t *testing.T) {
+	ctx := context.Background()
+
+	srvCfg1 := server.ListenerConfig{
+		Address: "localhost:1234",
+	}
+	srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		err := srv1.Listen(srvCfg1)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	defer func() {
+		err := srv1.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	srvCfg2 := server.ListenerConfig{
+		Address: "localhost:2234",
+	}
+	srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan2 := make(chan error, 1)
+	go func() {
+		err := srv2.Listen(srvCfg2)
+		if err != nil {
+			errChan2 <- err
+		}
+	}()
+
+	defer func() {
+		err := srv2.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	if err := waitForServerToStart(errChan2); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	idAlice := "alice"
+	log.Debugf("connect by alice")
+	mCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
+	err = mgr.Serve()
+	if err != nil {
+		t.Fatalf("failed to serve manager: %s", err)
+	}
+	// open a connection via server 2, which is foreign to this manager
+	conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
+	if err != nil {
+		t.Fatalf("failed to bind channel: %s", err)
+	}
+
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("failed to close connection: %s", err)
+	}
+}
+
+// TestForeginAutoClose verifies that the cleanup loop automatically drops a
+// foreign relay connection once its last peer connection is closed. It
+// shortens relayCleanupInterval so the test completes quickly.
+// NOTE(review): "Foregin" is a typo for "Foreign" in the test name.
+func TestForeginAutoClose(t *testing.T) {
+	ctx := context.Background()
+	relayCleanupInterval = 1 * time.Second
+	srvCfg1 := server.ListenerConfig{
+		Address: "localhost:1234",
+	}
+	srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		t.Log("binding server 1.")
+		err := srv1.Listen(srvCfg1)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	defer func() {
+		t.Logf("closing server 1.")
+		err := srv1.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+		t.Logf("server 1. closed")
+	}()
+
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	srvCfg2 := server.ListenerConfig{
+		Address: "localhost:2234",
+	}
+	srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan2 := make(chan error, 1)
+	go func() {
+		t.Log("binding server 2.")
+		err := srv2.Listen(srvCfg2)
+		if err != nil {
+			errChan2 <- err
+		}
+	}()
+	defer func() {
+		t.Logf("closing server 2.")
+		err := srv2.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+		t.Logf("server 2 closed.")
+	}()
+
+	if err := waitForServerToStart(errChan2); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	idAlice := "alice"
+	t.Log("connect to server 1.")
+	mCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	mgr := NewManager(mCtx, toURL(srvCfg1), idAlice)
+	err = mgr.Serve()
+	if err != nil {
+		t.Fatalf("failed to serve manager: %s", err)
+	}
+
+	t.Log("open connection to another peer")
+	conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer")
+	if err != nil {
+		t.Fatalf("failed to bind channel: %s", err)
+	}
+
+	t.Log("close conn")
+	err = conn.Close()
+	if err != nil {
+		t.Fatalf("failed to close connection: %s", err)
+	}
+
+	// after one cleanup cycle the unused foreign relay must be dropped
+	t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
+	time.Sleep(relayCleanupInterval + 1*time.Second)
+	if len(mgr.relayClients) != 0 {
+		t.Errorf("expected 0, got %d", len(mgr.relayClients))
+	}
+
+	t.Logf("closing manager")
+}
+
+// TestAutoReconnect kills the client's underlying relay connection, then
+// verifies reads fail on the dead connection and that after the reconnecting
+// timeout a fresh connection can be opened again.
+// NOTE(review): the log message on the final reopen contains a typo
+// ("reopent" -> "reopen").
+func TestAutoReconnect(t *testing.T) {
+	ctx := context.Background()
+	reconnectingTimeout = 2 * time.Second
+
+	srvCfg := server.ListenerConfig{
+		Address: "localhost:1234",
+	}
+	srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		err := srv.Listen(srvCfg)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	defer func() {
+		err := srv.Shutdown(ctx)
+		if err != nil {
+			log.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	mCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	clientAlice := NewManager(mCtx, toURL(srvCfg), "alice")
+	err = clientAlice.Serve()
+	if err != nil {
+		t.Fatalf("failed to serve manager: %s", err)
+	}
+	ra, err := clientAlice.RelayInstanceAddress()
+	if err != nil {
+		t.Errorf("failed to get relay address: %s", err)
+	}
+	conn, err := clientAlice.OpenConn(ra, "bob")
+	if err != nil {
+		t.Errorf("failed to bind channel: %s", err)
+	}
+
+	t.Log("closing client relay connection")
+	// todo figure out moc server
+	_ = clientAlice.relayClient.relayConn.Close()
+	t.Log("start test reading")
+	_, err = conn.Read(make([]byte, 1))
+	if err == nil {
+		t.Errorf("unexpected reading from closed connection")
+	}
+
+	log.Infof("waiting for reconnection")
+	time.Sleep(reconnectingTimeout + 1*time.Second)
+
+	log.Infof("reopent the connection")
+	_, err = clientAlice.OpenConn(ra, "bob")
+	if err != nil {
+		t.Errorf("failed to open channel: %s", err)
+	}
+}
+
+// TestNotifierDoubleAdd registers the same close listener twice and verifies
+// that AddCloseListener accepts both calls without error (the manager is
+// expected to deduplicate identical listeners internally).
+func TestNotifierDoubleAdd(t *testing.T) {
+	ctx := context.Background()
+
+	srvCfg1 := server.ListenerConfig{
+		Address: "localhost:1234",
+	}
+	srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av)
+	if err != nil {
+		t.Fatalf("failed to create server: %s", err)
+	}
+	errChan := make(chan error, 1)
+	go func() {
+		err := srv1.Listen(srvCfg1)
+		if err != nil {
+			errChan <- err
+		}
+	}()
+
+	defer func() {
+		err := srv1.Shutdown(ctx)
+		if err != nil {
+			t.Errorf("failed to close server: %s", err)
+		}
+	}()
+
+	if err := waitForServerToStart(errChan); err != nil {
+		t.Fatalf("failed to start server: %s", err)
+	}
+
+	idAlice := "alice"
+	log.Debugf("connect by alice")
+	mCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice)
+	err = clientAlice.Serve()
+	if err != nil {
+		t.Fatalf("failed to serve manager: %s", err)
+	}
+
+	conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob")
+	if err != nil {
+		t.Fatalf("failed to bind channel: %s", err)
+	}
+
+	fnCloseListener := OnServerCloseListener(func() {
+		log.Infof("close listener")
+	})
+
+	err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
+	if err != nil {
+		t.Fatalf("failed to add close listener: %s", err)
+	}
+
+	// adding the exact same listener again must be accepted (and deduplicated)
+	err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener)
+	if err != nil {
+		t.Fatalf("failed to add close listener: %s", err)
+	}
+
+	err = conn1.Close()
+	if err != nil {
+		t.Errorf("failed to close connection: %s", err)
+	}
+
+}
+
+// toURL converts a listener config address into a single-element slice with
+// the rel:// URL scheme used by the relay client.
+func toURL(address server.ListenerConfig) []string {
+	return []string{"rel://" + address.Address}
+}
diff --git a/relay/client/picker.go b/relay/client/picker.go
new file mode 100644
index 000000000..13b0547aa
--- /dev/null
+++ b/relay/client/picker.go
@@ -0,0 +1,98 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ auth "github.com/netbirdio/netbird/relay/auth/hmac"
+)
+
+const (
+	// connectionTimeout bounds the total time PickServer may spend dialing
+	connectionTimeout    = 30 * time.Second
+	// maxConcurrentServers caps how many servers are dialed in parallel
+	maxConcurrentServers = 7
+)
+
+// connResult carries the outcome of one connection attempt.
+type connResult struct {
+	RelayClient *Client
+	Url         string
+	Err         error
+}
+
+// ServerPicker dials multiple relay servers concurrently and selects the
+// first one that connects successfully.
+type ServerPicker struct {
+	TokenStore *auth.TokenStore
+	PeerID     string
+}
+
+// PickServer dials every URL concurrently (bounded by maxConcurrentServers)
+// and returns the first relay client that connects. It fails when all
+// attempts fail (successChan is closed without a value) or when the overall
+// connectionTimeout elapses.
+// NOTE(review): acquiring concurrentLimiter happens in this loop, so with
+// more than maxConcurrentServers URLs the dispatch itself blocks until a slot
+// frees up — this is related to the in-code todo about early exit on success.
+func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
+	ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
+	defer cancel()
+
+	totalServers := len(urls)
+
+	// buffered to totalServers so startConnection never blocks on send
+	connResultChan := make(chan connResult, totalServers)
+	successChan := make(chan connResult, 1)
+	concurrentLimiter := make(chan struct{}, maxConcurrentServers)
+
+	for _, url := range urls {
+		// todo check if we have a successful connection so we do not need to connect to other servers
+		concurrentLimiter <- struct{}{}
+		go func(url string) {
+			defer func() {
+				<-concurrentLimiter
+			}()
+			sp.startConnection(parentCtx, connResultChan, url)
+		}(url)
+	}
+
+	go sp.processConnResults(connResultChan, successChan)
+
+	select {
+	case cr, ok := <-successChan:
+		if !ok {
+			// channel closed without a success: every attempt failed
+			return nil, errors.New("failed to connect to any relay server: all attempts failed")
+		}
+		log.Infof("chosen home Relay server: %s", cr.Url)
+		return cr.RelayClient, nil
+	case <-ctx.Done():
+		return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err())
+	}
+}
+
+// startConnection dials one relay server and reports the outcome (client,
+// URL, and error, if any) on resultChan.
+// NOTE(review): the log message has a typo: "try to connecting" should read
+// "trying to connect".
+func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) {
+	log.Infof("try to connecting to relay server: %s", url)
+	relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID)
+	err := relayClient.Connect()
+	resultChan <- connResult{
+		RelayClient: relayClient,
+		Url:         url,
+		Err:         err,
+	}
+}
+
+// processConnResults consumes exactly cap(resultChan) results (one per dialed
+// server). The first success is forwarded on successChan; every later success
+// is closed as unnecessary. Closing successChan at the end lets PickServer
+// distinguish "all attempts failed" (closed, no value) from a real success.
+func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) {
+	var hasSuccess bool
+	for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
+		cr := <-resultChan
+		if cr.Err != nil {
+			log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
+			continue
+		}
+		log.Infof("connected to Relay server: %s", cr.Url)
+
+		if hasSuccess {
+			// a home server was already chosen; drop the extra connection
+			log.Infof("closing unnecessary Relay connection to: %s", cr.Url)
+			if err := cr.RelayClient.Close(); err != nil {
+				log.Errorf("failed to close connection to %s: %v", cr.Url, err)
+			}
+			continue
+		}
+
+		hasSuccess = true
+		successChan <- cr
+	}
+	close(successChan)
+}
diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go
new file mode 100644
index 000000000..eb14581e0
--- /dev/null
+++ b/relay/client/picker_test.go
@@ -0,0 +1,31 @@
+package client
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+)
+
+// TestServerPicker_UnavailableServers verifies that PickServer returns an
+// error for unreachable servers and does so well before the test deadline.
+func TestServerPicker_UnavailableServers(t *testing.T) {
+	sp := ServerPicker{
+		TokenStore: nil,
+		PeerID:     "test",
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+
+	go func() {
+		_, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
+		if err == nil {
+			// Bug fix: the original called t.Error(err) here, which logged
+			// a useless "<nil>" because err is nil on this branch.
+			t.Error("expected an error from PickServer for unavailable servers, got nil")
+		}
+		cancel()
+	}()
+
+	<-ctx.Done()
+	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+		t.Errorf("PickServer() took too long to complete")
+	}
+}
diff --git a/relay/cmd/env.go b/relay/cmd/env.go
new file mode 100644
index 000000000..3c15ebe1f
--- /dev/null
+++ b/relay/cmd/env.go
@@ -0,0 +1,35 @@
+package cmd
+
+import (
+ "os"
+ "strings"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/spf13/cobra"
+ "github.com/spf13/pflag"
+)
+
+// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
+func setFlagsFromEnvVars(cmd *cobra.Command) {
+ flags := cmd.PersistentFlags()
+ flags.VisitAll(func(f *pflag.Flag) {
+ newEnvVar := flagNameToEnvVar(f.Name, "NB_")
+ value, present := os.LookupEnv(newEnvVar)
+ if !present {
+ return
+ }
+
+ err := flags.Set(f.Name, value)
+ if err != nil {
+ log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
+ }
+ })
+}
+
+// flagNameToEnvVar converts flag name to environment var name adding a prefix,
+// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix)
+func flagNameToEnvVar(cmdFlag string, prefix string) string {
+	// dashes become underscores, then the whole name is uppercased
+	parsed := strings.ReplaceAll(cmdFlag, "-", "_")
+	upper := strings.ToUpper(parsed)
+	return prefix + upper
+}
diff --git a/relay/cmd/root.go b/relay/cmd/root.go
new file mode 100644
index 000000000..d603ff73b
--- /dev/null
+++ b/relay/cmd/root.go
@@ -0,0 +1,215 @@
+package cmd
+
+import (
+ "context"
+ "crypto/sha256"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/hashicorp/go-multierror"
+ log "github.com/sirupsen/logrus"
+ "github.com/spf13/cobra"
+
+ "github.com/netbirdio/netbird/encryption"
+ "github.com/netbirdio/netbird/relay/auth"
+ "github.com/netbirdio/netbird/relay/server"
+ "github.com/netbirdio/netbird/signal/metrics"
+ "github.com/netbirdio/netbird/util"
+)
+
+// Config holds the relay server's command-line / environment configuration.
+type Config struct {
+	// ListenAddress is the local address:port the listener binds to
+	ListenAddress string
+	// in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection
+	// it is a domain:port or ip:port
+	ExposedAddress string
+	// MetricsPort is the HTTP port for the /metrics endpoint
+	MetricsPort        int
+	LetsencryptEmail   string
+	LetsencryptDataDir string
+	LetsencryptDomains []string
+	// in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or
+	// in the AWS credentials file
+	LetsencryptAWSRoute53 bool
+	// TlsCertFile/TlsKeyFile enable file-based TLS when both are set
+	TlsCertFile string
+	TlsKeyFile  string
+	// AuthSecret is the shared HMAC secret used to validate peers
+	AuthSecret string
+	LogLevel   string
+	LogFile    string
+}
+
+// Validate checks the mandatory fields: the exposed address and the auth
+// secret must both be set.
+func (c Config) Validate() error {
+	if c.ExposedAddress == "" {
+		return fmt.Errorf("exposed address is required")
+	}
+	if c.AuthSecret == "" {
+		return fmt.Errorf("auth secret is required")
+	}
+	return nil
+}
+
+// HasCertConfig reports whether file-based TLS is configured (both the cert
+// and the key file paths are set).
+func (c Config) HasCertConfig() bool {
+	return c.TlsCertFile != "" && c.TlsKeyFile != ""
+}
+
+// HasLetsEncrypt reports whether a Let's Encrypt setup is configured: a data
+// directory plus at least one domain. The explicit nil check on the slice is
+// unnecessary because len of a nil slice is 0 (flagged by staticcheck/gocritic).
+func (c Config) HasLetsEncrypt() bool {
+	return c.LetsencryptDataDir != "" && len(c.LetsencryptDomains) > 0
+}
+
+var (
+	// cobraConfig is populated by flag/env parsing in init before execute runs
+	cobraConfig *Config
+	rootCmd     = &cobra.Command{
+		Use:           "relay",
+		Short:         "Relay service",
+		Long:          "Relay service for Netbird agents",
+		SilenceUsage:  true,
+		SilenceErrors: true,
+		RunE:          execute,
+	}
+)
+
+// init wires all CLI flags into cobraConfig and lets NB_-prefixed environment
+// variables override them. Logging starts at trace level until the real
+// configuration is applied in execute.
+func init() {
+	_ = util.InitLog("trace", "console")
+	cobraConfig = &Config{}
+	rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address")
+	rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers")
+	rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
+	rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.")
+	rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
+	rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration")
+	rootCmd.PersistentFlags().BoolVar(&cobraConfig.LetsencryptAWSRoute53, "letsencrypt-aws-route53", false, "use AWS Route 53 for Let's Encrypt DNS challenge")
+	rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsCertFile, "tls-cert-file", "c", "", "")
+	rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsKeyFile, "tls-key-file", "k", "", "")
+	rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret")
+	rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level")
+	rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file")
+
+	setFlagsFromEnvVars(rootCmd)
+}
+
+// Execute runs the root command; it is the package's entry point for main.
+func Execute() error {
+	return rootCmd.Execute()
+}
+
+// waitForExitSignal blocks until SIGINT or SIGTERM is delivered.
+func waitForExitSignal() {
+	osSigs := make(chan os.Signal, 1)
+	signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
+	<-osSigs
+}
+
+// execute is the cobra RunE handler: it validates config, re-initializes
+// logging, starts the metrics server and the relay server, blocks until an
+// exit signal, then shuts both down with a 30s grace period, aggregating
+// shutdown errors via multierror.
+// NOTE(review): errors are formatted with %s/%v instead of %w, so callers
+// cannot unwrap the original causes; the Debugf+return pairs also duplicate
+// each message.
+func execute(cmd *cobra.Command, args []string) error {
+	err := cobraConfig.Validate()
+	if err != nil {
+		log.Debugf("invalid config: %s", err)
+		return fmt.Errorf("invalid config: %s", err)
+	}
+
+	err = util.InitLog(cobraConfig.LogLevel, cobraConfig.LogFile)
+	if err != nil {
+		log.Debugf("failed to initialize log: %s", err)
+		return fmt.Errorf("failed to initialize log: %s", err)
+	}
+
+	metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "")
+	if err != nil {
+		log.Debugf("setup metrics: %v", err)
+		return fmt.Errorf("setup metrics: %v", err)
+	}
+
+	go func() {
+		log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint)
+		if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
+			log.Fatalf("Failed to start metrics server: %v", err)
+		}
+	}()
+
+	srvListenerCfg := server.ListenerConfig{
+		Address: cobraConfig.ListenAddress,
+	}
+
+	tlsConfig, tlsSupport, err := handleTLSConfig(cobraConfig)
+	if err != nil {
+		log.Debugf("failed to setup TLS config: %s", err)
+		return fmt.Errorf("failed to setup TLS config: %s", err)
+	}
+	srvListenerCfg.TLSConfig = tlsConfig
+
+	// peers authenticate with an HMAC over the SHA-256 of the shared secret
+	hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret))
+	authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour)
+
+	srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator)
+	if err != nil {
+		log.Debugf("failed to create relay server: %v", err)
+		return fmt.Errorf("failed to create relay server: %v", err)
+	}
+	log.Infof("server will be available on: %s", srv.InstanceURL())
+	go func() {
+		if err := srv.Listen(srvListenerCfg); err != nil {
+			log.Fatalf("failed to bind server: %s", err)
+		}
+	}()
+
+	// it will block until exit signal
+	waitForExitSignal()
+
+	// grace period for in-flight work during shutdown
+	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	defer cancel()
+
+	var shutDownErrors error
+	if err := srv.Shutdown(ctx); err != nil {
+		shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err))
+	}
+
+	log.Infof("shutting down metrics server")
+	if err := metricsServer.Shutdown(ctx); err != nil {
+		shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err))
+	}
+	return shutDownErrors
+}
+
+// handleTLSConfig builds the server TLS configuration, trying the options in
+// priority order: Route 53 DNS challenge, plain Let's Encrypt, then static
+// cert/key files. The bool result reports whether TLS is enabled; (nil,
+// false, nil) means the server runs without TLS.
+// NOTE(review): fmt.Errorf("%s", err) flattens the error to a string; using
+// %w would preserve the chain for errors.Is/As.
+func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) {
+	if cfg.LetsencryptAWSRoute53 {
+		log.Debugf("using Let's Encrypt DNS resolver with Route 53 support")
+		r53 := encryption.Route53TLS{
+			DataDir: cfg.LetsencryptDataDir,
+			Email:   cfg.LetsencryptEmail,
+			Domains: cfg.LetsencryptDomains,
+		}
+		tlsCfg, err := r53.GetCertificate()
+		if err != nil {
+			return nil, false, fmt.Errorf("%s", err)
+		}
+		return tlsCfg, true, nil
+	}
+
+	if cfg.HasLetsEncrypt() {
+		log.Infof("setting up TLS with Let's Encrypt.")
+		tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...)
+		if err != nil {
+			return nil, false, fmt.Errorf("%s", err)
+		}
+		return tlsCfg, true, nil
+	}
+
+	if cfg.HasCertConfig() {
+		log.Debugf("using file based TLS config")
+		tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile)
+		if err != nil {
+			return nil, false, fmt.Errorf("%s", err)
+		}
+		return tlsCfg, true, nil
+	}
+	return nil, false, nil
+}
+
+// setupTLSCertManager creates a Let's Encrypt cert manager for the given
+// domains and returns its TLS configuration.
+func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string) (*tls.Config, error) {
+	certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...)
+	if err != nil {
+		return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
+	}
+	return certManager.TLSConfig(), nil
+}
diff --git a/relay/doc.go b/relay/doc.go
new file mode 100644
index 000000000..56e010e3e
--- /dev/null
+++ b/relay/doc.go
@@ -0,0 +1,14 @@
+// Package main
+/*
+The `relay` package contains the implementation of the Relay server and client. The Relay server can be used to relay
+messages between peers on a single network channel. In this implementation the transport layer is the WebSocket
+protocol.
+
+A custom protocol and message format have been designed for the communication between the server and
+the client. These messages are transported over the WebSocket connection. Optionally the server can use TLS to secure the communication.
+
+The service can support multiple Relay server instances. For this purpose the peers must know the server instance URL.
+This URL will be sent to the target peer to choose the common Relay server for the communication via Signal service.
+
+*/
+package main
diff --git a/relay/healthcheck/doc.go b/relay/healthcheck/doc.go
new file mode 100644
index 000000000..da9689c6b
--- /dev/null
+++ b/relay/healthcheck/doc.go
@@ -0,0 +1,17 @@
+/*
+The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It
+ensures that the connection between the client and the server is alive and functioning properly.
+
+The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these
+signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does
+not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel
+and stop working.
+
+The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not
+receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working.
+
+In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is
+closed and the peer is removed from the relay.
+*/
+
+package healthcheck
diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go
new file mode 100644
index 000000000..b3503d5db
--- /dev/null
+++ b/relay/healthcheck/receiver.go
@@ -0,0 +1,94 @@
+package healthcheck
+
+import (
+ "context"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+var (
+	// heartbeatTimeout is deliberately longer than the sender's interval so a
+	// single on-time heartbeat always lands inside the window
+	heartbeatTimeout = healthCheckInterval + 10*time.Second
+)
+
+// Receiver is a healthcheck receiver.
+// It listens for heartbeats and checks whether each one arrives in time.
+// If no heartbeat is received within the window (for attemptThreshold
+// consecutive ticks), it sends a timeout signal on OnTimeout and stops working.
+// The heartbeat timeout is a bit longer than the sender's healthcheck interval.
+// NOTE(review): storing ctx in the struct is discouraged in Go; consider
+// passing it to waitForHealthcheck instead — confirm before changing.
+type Receiver struct {
+	OnTimeout chan struct{}
+	log       *log.Entry
+	ctx       context.Context
+	ctxCancel context.CancelFunc
+	heartbeat chan struct{}
+	alive     bool
+	attemptThreshold int
+}
+
+// NewReceiver creates a new healthcheck receiver and start the timer in the background
+func NewReceiver(log *log.Entry) *Receiver {
+ ctx, ctxCancel := context.WithCancel(context.Background())
+
+ r := &Receiver{
+ OnTimeout: make(chan struct{}, 1),
+ log: log,
+ ctx: ctx,
+ ctxCancel: ctxCancel,
+ heartbeat: make(chan struct{}, 1),
+ attemptThreshold: getAttemptThresholdFromEnv(),
+ }
+
+ go r.waitForHealthcheck()
+ return r
+}
+
+// Heartbeat acknowledge the heartbeat has been received
+func (r *Receiver) Heartbeat() {
+ select {
+ case r.heartbeat <- struct{}{}:
+ default:
+ }
+}
+
+// Stop check the timeout and do not send new notifications
+func (r *Receiver) Stop() {
+ r.ctxCancel()
+}
+
+// waitForHealthcheck is the watchdog loop. A heartbeat marks the receiver
+// alive and resets the failure counter; each timeout tick first consumes the
+// alive flag (one grace tick after a heartbeat), and only when the flag is
+// already false does the failure counter grow. After attemptThreshold
+// consecutive failures the timeout is reported and the loop exits. OnTimeout
+// is closed on every exit path so readers always unblock.
+func (r *Receiver) waitForHealthcheck() {
+	ticker := time.NewTicker(heartbeatTimeout)
+	defer ticker.Stop()
+	defer r.ctxCancel()
+	defer close(r.OnTimeout)
+
+	failureCounter := 0
+	for {
+		select {
+		case <-r.heartbeat:
+			r.alive = true
+			failureCounter = 0
+		case <-ticker.C:
+			if r.alive {
+				// grace tick: the last heartbeat was recent enough
+				r.alive = false
+				continue
+			}
+
+			failureCounter++
+			if failureCounter < r.attemptThreshold {
+				r.log.Warnf("healthcheck failed, attempt %d", failureCounter)
+				continue
+			}
+			r.notifyTimeout()
+			return
+		case <-r.ctx.Done():
+			return
+		}
+	}
+}
+
+// notifyTimeout sends the timeout signal without blocking; if a signal is
+// already pending, the new one is dropped.
+func (r *Receiver) notifyTimeout() {
+	select {
+	case r.OnTimeout <- struct{}{}:
+	default:
+	}
+}
diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go
new file mode 100644
index 000000000..3b3e32fe6
--- /dev/null
+++ b/relay/healthcheck/receiver_test.go
@@ -0,0 +1,97 @@
+package healthcheck
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+// TestNewReceiver verifies that a fresh receiver does not time out well
+// before the heartbeat deadline.
+func TestNewReceiver(t *testing.T) {
+	heartbeatTimeout = 5 * time.Second
+	r := NewReceiver(log.WithContext(context.Background()))
+
+	select {
+	case <-r.OnTimeout:
+		t.Error("unexpected timeout")
+	case <-time.After(1 * time.Second):
+
+	}
+}
+
+// TestNewReceiverNotReceive verifies that without any heartbeat the receiver
+// reports a timeout within the expected window.
+func TestNewReceiverNotReceive(t *testing.T) {
+	heartbeatTimeout = 1 * time.Second
+	r := NewReceiver(log.WithContext(context.Background()))
+
+	select {
+	case <-r.OnTimeout:
+	case <-time.After(2 * time.Second):
+		t.Error("timeout not received")
+	}
+}
+
+// TestNewReceiverAck verifies that a single heartbeat keeps the receiver
+// alive past one heartbeat window (the grace-tick behavior).
+func TestNewReceiverAck(t *testing.T) {
+	heartbeatTimeout = 2 * time.Second
+	r := NewReceiver(log.WithContext(context.Background()))
+
+	r.Heartbeat()
+
+	select {
+	case <-r.OnTimeout:
+		t.Error("unexpected timeout")
+	case <-time.After(3 * time.Second):
+	}
+}
+
+// TestReceiverHealthCheckAttemptThreshold checks that the receiver only times
+// out after the configured number of failed attempts (set through the
+// NB_RELAY_HC_ATTEMPT_THRESHOLD env var), and that one heartbeat resets the
+// failure counter, pushing the timeout past the computed test window.
+func TestReceiverHealthCheckAttemptThreshold(t *testing.T) {
+	testsCases := []struct {
+		name             string
+		threshold        int
+		resetCounterOnce bool
+	}{
+		{"Default attempt threshold", defaultAttemptThreshold, false},
+		{"Custom attempt threshold", 3, false},
+		{"Should reset threshold once", 2, true},
+	}
+
+	for _, tc := range testsCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// shrink the timing globals for the test and restore afterwards
+			originalInterval := healthCheckInterval
+			originalTimeout := heartbeatTimeout
+			healthCheckInterval = 1 * time.Second
+			heartbeatTimeout = healthCheckInterval + 500*time.Millisecond
+			defer func() {
+				healthCheckInterval = originalInterval
+				heartbeatTimeout = originalTimeout
+			}()
+			//nolint:tenv
+			os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
+			defer os.Unsetenv(defaultAttemptThresholdEnv)
+
+			receiver := NewReceiver(log.WithField("test_name", tc.name))
+
+			testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval
+
+			if tc.resetCounterOnce {
+				receiver.Heartbeat()
+				t.Logf("reset counter once")
+			}
+
+			select {
+			case <-receiver.OnTimeout:
+				if tc.resetCounterOnce {
+					t.Fatalf("should not have timed out before %s", testTimeout)
+				}
+			case <-time.After(testTimeout):
+				if tc.resetCounterOnce {
+					return
+				}
+				t.Fatalf("should have timed out before %s", testTimeout)
+			}
+
+		})
+	}
+}
diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go
new file mode 100644
index 000000000..57b3015ec
--- /dev/null
+++ b/relay/healthcheck/sender.go
@@ -0,0 +1,110 @@
+package healthcheck
+
+import (
+ "context"
+ "os"
+ "strconv"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+const (
+	// defaultAttemptThreshold is how many missed windows are tolerated before timeout
+	defaultAttemptThreshold    = 1
+	// defaultAttemptThresholdEnv overrides the threshold at runtime
+	defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD"
+)
+
+var (
+	// vars (not consts) so tests can shorten them
+	healthCheckInterval = 25 * time.Second
+	healthCheckTimeout  = 20 * time.Second
+)
+
+// Sender is a healthcheck sender.
+// It periodically emits a healthcheck signal to the receiver.
+// If no acknowledgment is received within the timeout window (for
+// attemptThreshold consecutive checks), it sends a timeout signal and stops
+// working. It also stops when the context is canceled.
+type Sender struct {
+	log *log.Entry
+	// HealthCheck is a channel to send health check signal to the peer
+	HealthCheck chan struct{}
+	// Timeout signals that the health check response was not received in time
+	Timeout chan struct{}
+
+	ack              chan struct{}
+	alive            bool
+	attemptThreshold int
+}
+
+// NewSender creates a new healthcheck sender. The caller must start the loop
+// with StartHealthCheck; the attempt threshold can be overridden via the
+// NB_RELAY_HC_ATTEMPT_THRESHOLD environment variable.
+func NewSender(log *log.Entry) *Sender {
+	hc := &Sender{
+		log:         log,
+		HealthCheck: make(chan struct{}, 1),
+		Timeout:     make(chan struct{}, 1),
+		ack:         make(chan struct{}, 1),
+		attemptThreshold: getAttemptThresholdFromEnv(),
+	}
+
+	return hc
+}
+
+// OnHCResponse records an acknowledgment from the peer. The send is
+// non-blocking: a pending ack in the buffered channel absorbs duplicates.
+func (hc *Sender) OnHCResponse() {
+	select {
+	case hc.ack <- struct{}{}:
+	default:
+	}
+}
+
+// StartHealthCheck runs the sender loop until ctx is canceled or the peer
+// stops acknowledging. An ack marks the sender alive and resets the failure
+// counter; each timeout tick first consumes the alive flag (one grace tick),
+// and only then increments failures up to attemptThreshold before signaling
+// Timeout. Both output channels are closed on exit.
+// NOTE(review): the send on hc.HealthCheck (buffer 1) can block if the
+// consumer stalls for more than one interval — confirm the consumer always
+// drains it.
+func (hc *Sender) StartHealthCheck(ctx context.Context) {
+	ticker := time.NewTicker(healthCheckInterval)
+	defer ticker.Stop()
+
+	timeoutTicker := time.NewTicker(hc.getTimeoutTime())
+	defer timeoutTicker.Stop()
+
+	defer close(hc.HealthCheck)
+	defer close(hc.Timeout)
+
+	failureCounter := 0
+	for {
+		select {
+		case <-ticker.C:
+			hc.HealthCheck <- struct{}{}
+		case <-timeoutTicker.C:
+			if hc.alive {
+				// grace tick: an ack arrived during the last window
+				hc.alive = false
+				continue
+			}
+
+			failureCounter++
+			if failureCounter < hc.attemptThreshold {
+				hc.log.Warnf("Health check failed attempt %d.", failureCounter)
+				continue
+			}
+			hc.Timeout <- struct{}{}
+			return
+		case <-hc.ack:
+			failureCounter = 0
+			hc.alive = true
+		case <-ctx.Done():
+			return
+		}
+	}
+}
+
+// getTimeoutTime returns the ack deadline: one full interval plus the
+// configured response timeout.
+func (hc *Sender) getTimeoutTime() time.Duration {
+	return healthCheckInterval + healthCheckTimeout
+}
+
+// getAttemptThresholdFromEnv reads the attempt threshold from the
+// NB_RELAY_HC_ATTEMPT_THRESHOLD environment variable, falling back to the
+// default when unset or unparsable.
+// NOTE(review): the error log's %s receives the variable's *value*, not the
+// variable name — the sentence reads as if it were the name; also
+// strconv.Atoi would be the simpler call here.
+func getAttemptThresholdFromEnv() int {
+	if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" {
+		threshold, err := strconv.ParseInt(attemptThreshold, 10, 64)
+		if err != nil {
+			log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold)
+			return defaultAttemptThreshold
+		}
+		return int(threshold)
+	}
+	return defaultAttemptThreshold
+}
diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go
new file mode 100644
index 000000000..f21167025
--- /dev/null
+++ b/relay/healthcheck/sender_test.go
@@ -0,0 +1,205 @@
+package healthcheck
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "testing"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
// TestMain shortens the package-level healthcheck timing variables so the
// suite runs quickly, then executes the tests.
func TestMain(m *testing.M) {
	// override the health check interval to speed up the test
	healthCheckInterval = 2 * time.Second
	healthCheckTimeout = 100 * time.Millisecond
	code := m.Run()
	os.Exit(code)
}
+
// TestNewHealthPeriod verifies that the sender emits periodic healthcheck
// signals and does not time out while every signal is acknowledged.
func TestNewHealthPeriod(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hc := NewSender(log.WithContext(ctx))
	go hc.StartHealthCheck(ctx)

	iterations := 0
	for i := 0; i < 3; i++ {
		select {
		case <-hc.HealthCheck:
			iterations++
			// Acknowledge immediately so the failure counter stays at zero.
			hc.OnHCResponse()
		case <-hc.Timeout:
			t.Fatalf("health check is timed out")
		case <-time.After(healthCheckInterval + 100*time.Millisecond):
			t.Fatalf("health check not received")
		}
	}
}
+
// TestNewHealthFailed verifies that a sender whose healthchecks are never
// acknowledged eventually signals on Timeout.
// NOTE(review): the wait budget covers a single interval+timeout; with an
// attempt threshold > 1 configured via env this could be tight — confirm the
// default threshold used in CI.
func TestNewHealthFailed(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hc := NewSender(log.WithContext(ctx))
	go hc.StartHealthCheck(ctx)

	select {
	case <-hc.Timeout:
	case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond):
		t.Fatalf("health check is not timed out")
	}
}
+
+func TestNewHealthcheckStop(t *testing.T) {
+ ctx, cancel := context.WithCancel(context.Background())
+ hc := NewSender(log.WithContext(ctx))
+ go hc.StartHealthCheck(ctx)
+
+ time.Sleep(100 * time.Millisecond)
+ cancel()
+
+ select {
+ case _, ok := <-hc.HealthCheck:
+ if ok {
+ t.Fatalf("health check on received")
+ }
+ case _, ok := <-hc.Timeout:
+ if ok {
+ t.Fatalf("health check on received")
+ }
+ case <-ctx.Done():
+ // expected
+ case <-time.After(10 * time.Second):
+ t.Fatalf("is not exited")
+ }
+}
+
// TestTimeoutReset first keeps the sender alive by acknowledging several
// healthchecks, then stops responding and expects a Timeout signal.
func TestTimeoutReset(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	hc := NewSender(log.WithContext(ctx))
	go hc.StartHealthCheck(ctx)

	// iterations is only incremented, never asserted — presumably kept for
	// debugging; TODO confirm whether it can be removed.
	iterations := 0
	for i := 0; i < 3; i++ {
		select {
		case <-hc.HealthCheck:
			iterations++
			hc.OnHCResponse()
		case <-hc.Timeout:
			t.Fatalf("health check is timed out")
		case <-time.After(healthCheckInterval + 100*time.Millisecond):
			t.Fatalf("health check not received")
		}
	}

	// Stop acknowledging; either another HealthCheck request or the Timeout
	// may arrive first, both are acceptable here.
	select {
	case <-hc.HealthCheck:
	case <-hc.Timeout:
		// expected
	case <-ctx.Done():
		t.Fatalf("context is done")
	case <-time.After(10 * time.Second):
		t.Fatalf("is not exited")
	}
}
+
// TestSenderHealthCheckAttemptThreshold verifies that the sender only times
// out after the configured number of missed healthchecks, and that a single
// response resets the failure counter.
func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
	testsCases := []struct {
		name             string
		threshold        int
		resetCounterOnce bool
	}{
		{"Default attempt threshold", defaultAttemptThreshold, false},
		{"Custom attempt threshold", 3, false},
		{"Should reset threshold once", 2, true},
	}

	for _, tc := range testsCases {
		t.Run(tc.name, func(t *testing.T) {
			// Shorten the timing globals for the duration of this subtest.
			originalInterval := healthCheckInterval
			originalTimeout := healthCheckTimeout
			healthCheckInterval = 1 * time.Second
			healthCheckTimeout = 500 * time.Millisecond
			defer func() {
				healthCheckInterval = originalInterval
				healthCheckTimeout = originalTimeout
			}()

			//nolint:tenv
			os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
			defer os.Unsetenv(defaultAttemptThresholdEnv)

			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()

			sender := NewSender(log.WithField("test_name", tc.name))
			go sender.StartHealthCheck(ctx)

			// Drain healthcheck requests; optionally answer exactly one to
			// exercise the failure-counter reset path.
			go func() {
				responded := false
				for {
					select {
					case <-ctx.Done():
						return
					case _, ok := <-sender.HealthCheck:
						if !ok {
							return
						}
						if tc.resetCounterOnce && !responded {
							responded = true
							sender.OnHCResponse()
						}
					}
				}
			}()

			// Upper bound: threshold timeout periods plus one send interval.
			testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval

			select {
			case <-sender.Timeout:
				if tc.resetCounterOnce {
					t.Fatalf("should not have timed out before %s", testTimeout)
				}
			case <-time.After(testTimeout):
				if tc.resetCounterOnce {
					return
				}
				t.Fatalf("should have timed out before %s", testTimeout)
			}

		})
	}

}
+
//nolint:tenv
// TestGetAttemptThresholdFromEnv covers the three parsing outcomes: unset
// variable, valid integer, and invalid value (falls back to the default).
func TestGetAttemptThresholdFromEnv(t *testing.T) {
	tests := []struct {
		name     string
		envValue string
		expected int
	}{
		{"Default attempt threshold when env is not set", "", defaultAttemptThreshold},
		{"Custom attempt threshold when env is set to a valid integer", "3", 3},
		{"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.envValue == "" {
				os.Unsetenv(defaultAttemptThresholdEnv)
			} else {
				os.Setenv(defaultAttemptThresholdEnv, tt.envValue)
			}

			result := getAttemptThresholdFromEnv()
			if result != tt.expected {
				t.Fatalf("Expected %d, got %d", tt.expected, result)
			}

			// Clean up so later tests see an unset variable.
			os.Unsetenv(defaultAttemptThresholdEnv)
		})
	}
}
diff --git a/relay/main.go b/relay/main.go
new file mode 100644
index 000000000..e28f73603
--- /dev/null
+++ b/relay/main.go
@@ -0,0 +1,13 @@
+package main
+
+import (
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/relay/cmd"
+)
+
// main runs the relay command-line entry point and terminates the process
// with a fatal log message if execution fails.
func main() {
	if err := cmd.Execute(); err != nil {
		log.Fatalf("failed to execute command: %v", err)
	}
}
diff --git a/relay/messages/address/address.go b/relay/messages/address/address.go
new file mode 100644
index 000000000..707e73e55
--- /dev/null
+++ b/relay/messages/address/address.go
@@ -0,0 +1,21 @@
+// Deprecated: This package is deprecated and will be removed in a future release.
+package address
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+)
+
// Address carries the relay server's instance URL.
type Address struct {
	URL string
}

// Marshal gob-encodes the address for transmission on the wire.
func (addr *Address) Marshal() ([]byte, error) {
	var out bytes.Buffer
	if err := gob.NewEncoder(&out).Encode(addr); err != nil {
		return nil, fmt.Errorf("encode Address: %w", err)
	}
	return out.Bytes(), nil
}
diff --git a/relay/messages/auth/auth.go b/relay/messages/auth/auth.go
new file mode 100644
index 000000000..9c2511f2f
--- /dev/null
+++ b/relay/messages/auth/auth.go
@@ -0,0 +1,43 @@
+// Deprecated: This package is deprecated and will be removed in a future release.
+package auth
+
+import (
+ "bytes"
+ "encoding/gob"
+ "fmt"
+)
+
// Algorithm identifies the HMAC algorithm used for authentication.
type Algorithm int

const (
	AlgoUnknown Algorithm = iota
	AlgoHMACSHA256
	AlgoHMACSHA512
)

// String returns the canonical name of the algorithm, or "Unknown" for any
// unrecognized value.
func (a Algorithm) String() string {
	switch a {
	case AlgoHMACSHA256:
		return "HMAC-SHA256"
	case AlgoHMACSHA512:
		return "HMAC-SHA512"
	}
	return "Unknown"
}

// Msg is the gob-encoded authentication payload.
type Msg struct {
	AuthAlgorithm  Algorithm
	AdditionalData []byte
}

// UnmarshalMsg gob-decodes an authentication message.
func UnmarshalMsg(data []byte) (*Msg, error) {
	var msg *Msg
	if err := gob.NewDecoder(bytes.NewBuffer(data)).Decode(&msg); err != nil {
		return nil, fmt.Errorf("decode Msg: %w", err)
	}
	return msg, nil
}
diff --git a/relay/messages/doc.go b/relay/messages/doc.go
new file mode 100644
index 000000000..4c719df3a
--- /dev/null
+++ b/relay/messages/doc.go
@@ -0,0 +1,5 @@
+/*
+Package messages provides the message types that are used to communicate between the relay and the client.
+This package is used to determine the type of message that is being sent and received between the relay and the client.
+*/
+package messages
diff --git a/relay/messages/id.go b/relay/messages/id.go
new file mode 100644
index 000000000..e2162cd3b
--- /dev/null
+++ b/relay/messages/id.go
@@ -0,0 +1,31 @@
+package messages
+
+import (
+ "crypto/sha256"
+ "encoding/base64"
+ "fmt"
+)
+
const (
	prefixLength = 4
	IDSize       = prefixLength + sha256.Size
)

var (
	prefix = []byte("sha-") // 4 bytes
)

// HashID generates a sha256 hash from the peerID and returns both the
// prefixed binary hash and its human-readable string form.
func HashID(peerID string) ([]byte, string) {
	digest := sha256.Sum256([]byte(peerID))
	human := string(prefix) + base64.StdEncoding.EncodeToString(digest[:])

	id := make([]byte, 0, IDSize)
	id = append(id, prefix...)
	id = append(id, digest[:]...)
	return id, human
}

// HashIDToString converts a prefixed binary hash to its human-readable form.
func HashIDToString(idHash []byte) string {
	return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:]))
}
diff --git a/relay/messages/id_test.go b/relay/messages/id_test.go
new file mode 100644
index 000000000..271a8f90d
--- /dev/null
+++ b/relay/messages/id_test.go
@@ -0,0 +1,13 @@
+package messages
+
+import (
+ "testing"
+)
+
// TestHashID checks that the string form returned by HashID matches the one
// reconstructed from the binary hash by HashIDToString.
func TestHashID(t *testing.T) {
	hashedID, hashedStringId := HashID("alice")
	enc := HashIDToString(hashedID)
	if enc != hashedStringId {
		t.Errorf("expected %s, got %s", hashedStringId, enc)
	}
}
diff --git a/relay/messages/message.go b/relay/messages/message.go
new file mode 100644
index 000000000..39ca0aa90
--- /dev/null
+++ b/relay/messages/message.go
@@ -0,0 +1,317 @@
+package messages
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+)
+
+const (
+ MaxHandshakeSize = 212
+ MaxHandshakeRespSize = 8192
+
+ CurrentProtocolVersion = 1
+
+ MsgTypeUnknown MsgType = 0
+ // Deprecated: Use MsgTypeAuth instead.
+ MsgTypeHello MsgType = 1
+ // Deprecated: Use MsgTypeAuthResponse instead.
+ MsgTypeHelloResponse MsgType = 2
+ MsgTypeTransport MsgType = 3
+ MsgTypeClose MsgType = 4
+ MsgTypeHealthCheck MsgType = 5
+ MsgTypeAuth = 6
+ MsgTypeAuthResponse = 7
+
+ SizeOfVersionByte = 1
+ SizeOfMsgType = 1
+
+ SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType
+
+ sizeOfMagicByte = 4
+
+ headerSizeTransport = IDSize
+
+ headerSizeHello = sizeOfMagicByte + IDSize
+ headerSizeHelloResp = 0
+
+ headerSizeAuth = sizeOfMagicByte + IDSize
+ headerSizeAuthResp = 0
+)
+
+var (
+ ErrInvalidMessageLength = errors.New("invalid message length")
+ ErrUnsupportedVersion = errors.New("unsupported version")
+
+ magicHeader = []byte{0x21, 0x12, 0xA4, 0x42}
+
+ healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)}
+)
+
+type MsgType byte
+
+func (m MsgType) String() string {
+ switch m {
+ case MsgTypeHello:
+ return "hello"
+ case MsgTypeHelloResponse:
+ return "hello response"
+ case MsgTypeAuth:
+ return "auth"
+ case MsgTypeAuthResponse:
+ return "auth response"
+ case MsgTypeTransport:
+ return "transport"
+ case MsgTypeClose:
+ return "close"
+ case MsgTypeHealthCheck:
+ return "health check"
+ default:
+ return "unknown"
+ }
+}
+
+// ValidateVersion checks if the given version is supported by the protocol
+func ValidateVersion(msg []byte) (int, error) {
+ if len(msg) < SizeOfVersionByte {
+ return 0, ErrInvalidMessageLength
+ }
+ version := int(msg[0])
+ if version != CurrentProtocolVersion {
+ return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion)
+ }
+ return version, nil
+}
+
+// DetermineClientMessageType determines the message type from the first the message
+func DetermineClientMessageType(msg []byte) (MsgType, error) {
+ if len(msg) < SizeOfMsgType {
+ return 0, ErrInvalidMessageLength
+ }
+
+ msgType := MsgType(msg[0])
+ switch msgType {
+ case
+ MsgTypeHello,
+ MsgTypeAuth,
+ MsgTypeTransport,
+ MsgTypeClose,
+ MsgTypeHealthCheck:
+ return msgType, nil
+ default:
+ return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
+ }
+}
+
+// DetermineServerMessageType determines the message type from the first the message
+func DetermineServerMessageType(msg []byte) (MsgType, error) {
+ if len(msg) < SizeOfMsgType {
+ return 0, ErrInvalidMessageLength
+ }
+
+ msgType := MsgType(msg[0])
+ switch msgType {
+ case
+ MsgTypeHelloResponse,
+ MsgTypeAuthResponse,
+ MsgTypeTransport,
+ MsgTypeClose,
+ MsgTypeHealthCheck:
+ return msgType, nil
+ default:
+ return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType)
+ }
+}
+
// Deprecated: Use MarshalAuthMsg instead.
// MarshalHelloMsg initial hello message
// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
//
// Wire layout: [version][msg type][magic 4B][peer ID (IDSize)][additions...]
func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) {
	if len(peerID) != IDSize {
		return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
	}

	// Length covers header+magic; capacity pre-sizes for the appends below.
	msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions))

	msg[0] = byte(CurrentProtocolVersion)
	msg[1] = byte(MsgTypeHello)

	copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)

	msg = append(msg, peerID...)
	msg = append(msg, additions...)

	return msg, nil
}

// Deprecated: Use UnmarshalAuthMsg instead.
// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to
// authenticate the client with the server.
// Note: msg must NOT include the two-byte protocol header; callers strip it
// before calling.
func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) {
	if len(msg) < headerSizeHello {
		return nil, nil, ErrInvalidMessageLength
	}
	if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
		return nil, nil, errors.New("invalid magic header")
	}

	return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil
}
+
// Deprecated: Use MarshalAuthResponse instead.
// MarshalHelloResponse creates a response message to the hello message.
// In case of success connection the server response with a Hello Response message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
//
// Wire layout: [version][msg type][additionalData...]
func MarshalHelloResponse(additionalData []byte) ([]byte, error) {
	msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData))

	msg[0] = byte(CurrentProtocolVersion)
	msg[1] = byte(MsgTypeHelloResponse)

	msg = append(msg, additionalData...)

	return msg, nil
}

// Deprecated: Use UnmarshalAuthResponse instead.
// UnmarshalHelloResponse extracts the additional data from the hello response message.
// Note: msg must NOT include the two-byte protocol header; since
// headerSizeHelloResp is 0 any payload (including empty) is accepted.
func UnmarshalHelloResponse(msg []byte) ([]byte, error) {
	if len(msg) < headerSizeHelloResp {
		return nil, ErrInvalidMessageLength
	}
	return msg, nil
}
+
// MarshalAuthMsg initial authentication message
// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This
// message is used to authenticate the client with the server. The authentication is done using an HMAC method.
// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will
// close the network connection without any response.
//
// Wire layout: [version][msg type][magic 4B][peer ID (IDSize)][authPayload...]
func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) {
	if len(peerID) != IDSize {
		return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
	}

	// Length covers header+magic; capacity pre-sizes for the appends below.
	msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload))

	msg[0] = byte(CurrentProtocolVersion)
	msg[1] = byte(MsgTypeAuth)

	copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader)

	msg = append(msg, peerID...)
	msg = append(msg, authPayload...)

	return msg, nil
}

// UnmarshalAuthMsg extracts peerID and the auth payload from the message.
// Note: msg must NOT include the two-byte protocol header; callers strip it
// before calling.
func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) {
	if len(msg) < headerSizeAuth {
		return nil, nil, ErrInvalidMessageLength
	}
	if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) {
		return nil, nil, errors.New("invalid magic header")
	}

	return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil
}
+
// MarshalAuthResponse creates a response message to the auth.
// In case of success connection the server response with a AuthResponse message. This message contains the server's
// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay
// servers.
//
// Wire layout: [version][msg type][address bytes...]; the total frame size is
// capped at MaxHandshakeRespSize.
func MarshalAuthResponse(address string) ([]byte, error) {
	ab := []byte(address)
	msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab))

	msg[0] = byte(CurrentProtocolVersion)
	msg[1] = byte(MsgTypeAuthResponse)

	msg = append(msg, ab...)

	if len(msg) > MaxHandshakeRespSize {
		return nil, fmt.Errorf("invalid message length: %d", len(msg))
	}

	return msg, nil
}

// UnmarshalAuthResponse it is a confirmation message to auth success.
// The body (protocol header already stripped by the caller) must be
// non-empty, since it carries the server's instance URL.
func UnmarshalAuthResponse(msg []byte) (string, error) {
	if len(msg) < headerSizeAuthResp+1 {
		return "", ErrInvalidMessageLength
	}
	return string(msg), nil
}
+
+// MarshalCloseMsg creates a close message.
+// The close message is used to close the connection gracefully between the client and the server. The server and the
+// client can send this message. After receiving this message, the server or client will close the connection.
+func MarshalCloseMsg() []byte {
+ msg := make([]byte, SizeOfProtoHeader)
+
+ msg[0] = byte(CurrentProtocolVersion)
+ msg[1] = byte(MsgTypeClose)
+
+ return msg
+}
+
// MarshalTransportMsg creates a transport message.
// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the
// destination peer hashed ID.
//
// Wire layout: [version][msg type][peer ID (IDSize)][payload...]
func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) {
	if len(peerID) != IDSize {
		return nil, fmt.Errorf("invalid peerID length: %d", len(peerID))
	}

	msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload))

	msg[0] = byte(CurrentProtocolVersion)
	msg[1] = byte(MsgTypeTransport)

	copy(msg[SizeOfProtoHeader:], peerID)

	msg = append(msg, payload...)

	return msg, nil
}

// UnmarshalTransportMsg extracts the peerID and the payload from the transport message.
// Note: buf must NOT include the two-byte protocol header. The returned
// slices alias buf; they stay valid only as long as buf is unchanged.
func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) {
	if len(buf) < headerSizeTransport {
		return nil, nil, ErrInvalidMessageLength
	}

	return buf[:headerSizeTransport], buf[headerSizeTransport:], nil
}

// UnmarshalTransportID extracts the peerID from the transport message.
// Note: buf must NOT include the two-byte protocol header; the returned
// slice aliases buf.
func UnmarshalTransportID(buf []byte) ([]byte, error) {
	if len(buf) < headerSizeTransport {
		return nil, ErrInvalidMessageLength
	}
	return buf[:headerSizeTransport], nil
}
+
// UpdateTransportMsg updates the peerID in the transport message.
// With this function the server can reuse the given byte slice to rewrite the
// peerID in place, so there is no need to allocate a new byte slice.
// msg must start at the ID field (protocol header already stripped).
func UpdateTransportMsg(msg []byte, peerID []byte) error {
	if len(msg) < len(peerID) {
		return ErrInvalidMessageLength
	}
	copy(msg, peerID)
	return nil
}
+
// MarshalHealthcheck creates a health check message.
// Health check message is sent by the server periodically. The client will respond with a health check response
// message. If the client does not respond to the health check message, the server will close the connection.
// Note: the returned slice is a shared package-level buffer; callers must not
// modify it.
func MarshalHealthcheck() []byte {
	return healthCheckMsg
}
diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go
new file mode 100644
index 000000000..6e917da71
--- /dev/null
+++ b/relay/messages/message_test.go
@@ -0,0 +1,59 @@
+package messages
+
+import (
+ "testing"
+)
+
// TestMarshalHelloMsg round-trips a hello message; the unmarshal side is fed
// the frame with the two-byte protocol header stripped, as the server does.
func TestMarshalHelloMsg(t *testing.T) {
	peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
	bHello, err := MarshalHelloMsg(peerID, nil)
	if err != nil {
		t.Fatalf("error: %v", err)
	}

	receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:])
	if err != nil {
		t.Fatalf("error: %v", err)
	}
	if string(receivedPeerID) != string(peerID) {
		t.Errorf("expected %s, got %s", peerID, receivedPeerID)
	}
}
+
// TestMarshalAuthMsg round-trips an auth message; the unmarshal side is fed
// the frame with the two-byte protocol header stripped, as the server does.
func TestMarshalAuthMsg(t *testing.T) {
	peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
	bHello, err := MarshalAuthMsg(peerID, []byte{})
	if err != nil {
		t.Fatalf("error: %v", err)
	}

	receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:])
	if err != nil {
		t.Fatalf("error: %v", err)
	}
	if string(receivedPeerID) != string(peerID) {
		t.Errorf("expected %s, got %s", peerID, receivedPeerID)
	}
}
+
// TestMarshalTransportMsg round-trips a transport message, checking both the
// destination ID and the payload survive marshal/unmarshal.
func TestMarshalTransportMsg(t *testing.T) {
	peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+")
	payload := []byte("payload")
	msg, err := MarshalTransportMsg(peerID, payload)
	if err != nil {
		t.Fatalf("error: %v", err)
	}

	id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:])
	if err != nil {
		t.Fatalf("error: %v", err)
	}

	if string(id) != string(peerID) {
		t.Errorf("expected %s, got %s", peerID, id)
	}

	if string(respPayload) != string(payload) {
		t.Errorf("expected %s, got %s", payload, respPayload)
	}
}
diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go
new file mode 100644
index 000000000..13799713a
--- /dev/null
+++ b/relay/metrics/realy.go
@@ -0,0 +1,136 @@
+package metrics
+
+import (
+ "context"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+)
+
const (
	// idleTimeout is how long a peer may go without traffic before the
	// active/idle gauges count it as idle.
	idleTimeout = 30 * time.Second
)

// Metrics holds the OpenTelemetry instruments for the relay server plus the
// bookkeeping needed to derive the active/idle peer gauges.
type Metrics struct {
	metric.Meter

	TransferBytesSent metric.Int64Counter
	TransferBytesRecv metric.Int64Counter

	peers metric.Int64UpDownCounter
	// peerActivityChan decouples hot-path activity reporting from the
	// mutex-guarded map update (see PeerActivity / readPeerActivity).
	peerActivityChan chan string
	// peerLastActive maps peer ID to its last activity time; the zero time
	// marks a freshly connected (idle) peer.
	peerLastActive map[string]time.Time
	mutexActivity  sync.Mutex
	ctx            context.Context
}
+
// NewMetrics registers the relay metrics on the given meter and starts the
// background goroutine that records per-peer activity timestamps. ctx bounds
// the lifetime of that goroutine and is used for metric recording.
func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) {
	bytesSent, err := meter.Int64Counter("relay_transfer_sent_bytes_total")
	if err != nil {
		return nil, err
	}

	bytesRecv, err := meter.Int64Counter("relay_transfer_received_bytes_total")
	if err != nil {
		return nil, err
	}

	peers, err := meter.Int64UpDownCounter("relay_peers")
	if err != nil {
		return nil, err
	}

	peersActive, err := meter.Int64ObservableGauge("relay_peers_active")
	if err != nil {
		return nil, err
	}

	peersIdle, err := meter.Int64ObservableGauge("relay_peers_idle")
	if err != nil {
		return nil, err
	}

	m := &Metrics{
		Meter:             meter,
		TransferBytesSent: bytesSent,
		TransferBytesRecv: bytesRecv,
		peers:             peers,

		ctx:              ctx,
		peerActivityChan: make(chan string, 10),
		peerLastActive:   make(map[string]time.Time),
	}

	// The active/idle gauges are computed lazily from the last-activity
	// timestamps each time the meter collects.
	_, err = meter.RegisterCallback(
		func(ctx context.Context, o metric.Observer) error {
			active, idle := m.calculateActiveIdleConnections()
			o.ObserveInt64(peersActive, active)
			o.ObserveInt64(peersIdle, idle)
			return nil
		},
		peersActive, peersIdle,
	)
	if err != nil {
		return nil, err
	}

	go m.readPeerActivity()
	return m, nil
}
+
+// PeerConnected increments the number of connected peers and increments number of idle connections
+func (m *Metrics) PeerConnected(id string) {
+ m.peers.Add(m.ctx, 1)
+ m.mutexActivity.Lock()
+ defer m.mutexActivity.Unlock()
+
+ m.peerLastActive[id] = time.Time{}
+}
+
+// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections
+func (m *Metrics) PeerDisconnected(id string) {
+ m.peers.Add(m.ctx, -1)
+ m.mutexActivity.Lock()
+ defer m.mutexActivity.Unlock()
+
+ delete(m.peerLastActive, id)
+}
+
// PeerActivity records traffic for a peer. The send is non-blocking so the
// data path never stalls on metrics; if the channel is full the sample is
// dropped (only the last-active timestamp is slightly stale in that case).
func (m *Metrics) PeerActivity(peerID string) {
	select {
	case m.peerActivityChan <- peerID:
	default:
		log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID)
	}
}
+
+func (m *Metrics) calculateActiveIdleConnections() (int64, int64) {
+ active, idle := int64(0), int64(0)
+ m.mutexActivity.Lock()
+ defer m.mutexActivity.Unlock()
+
+ for _, lastActive := range m.peerLastActive {
+ if time.Since(lastActive) > idleTimeout {
+ idle++
+ } else {
+ active++
+ }
+ }
+ return active, idle
+}
+
+func (m *Metrics) readPeerActivity() {
+ for {
+ select {
+ case peerID := <-m.peerActivityChan:
+ m.mutexActivity.Lock()
+ m.peerLastActive[peerID] = time.Now()
+ m.mutexActivity.Unlock()
+ case <-m.ctx.Done():
+ return
+ }
+ }
+}
diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go
new file mode 100644
index 000000000..535c8bcd9
--- /dev/null
+++ b/relay/server/listener/listener.go
@@ -0,0 +1,11 @@
+package listener
+
+import (
+ "context"
+ "net"
+)
+
// Listener accepts inbound connections and hands each accepted connection to
// the provided callback.
type Listener interface {
	// Listen blocks, invoking the callback for every accepted connection,
	// until Shutdown is called.
	Listen(func(conn net.Conn)) error
	// Shutdown stops accepting and releases the listener, honoring ctx.
	Shutdown(ctx context.Context) error
}
diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go
new file mode 100644
index 000000000..c248963b9
--- /dev/null
+++ b/relay/server/listener/ws/conn.go
@@ -0,0 +1,114 @@
+package ws
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+ "nhooyr.io/websocket"
+)
+
const (
	// writeTimeout bounds a single websocket write (see Write).
	writeTimeout = 10 * time.Second
)

// Conn adapts a server-side websocket connection to the net.Conn interface
// so the relay can treat websocket clients like ordinary stream connections.
type Conn struct {
	*websocket.Conn
	lAddr *net.TCPAddr // reported by LocalAddr
	rAddr *net.TCPAddr // reported by RemoteAddr

	closed   bool       // set by Close; lets ioErrHandling map errors to io.EOF
	closedMu sync.Mutex // guards closed
	ctx      context.Context
}

// NewConn wraps wsConn with the local/remote TCP addresses captured at
// accept time.
func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn {
	return &Conn{
		Conn:  wsConn,
		lAddr: lAddr,
		rAddr: rAddr,
		ctx:   context.Background(),
	}
}
+
// Read reads the next binary websocket message into b and returns the number
// of bytes read. Non-binary messages are rejected. Errors observed after a
// local Close are normalized to io.EOF by ioErrHandling.
// NOTE(review): a message larger than b is only partially read here; confirm
// callers always supply a buffer at least as large as the peer's frames.
func (c *Conn) Read(b []byte) (n int, err error) {
	t, r, err := c.Reader(c.ctx)
	if err != nil {
		return 0, c.ioErrHandling(err)
	}

	if t != websocket.MessageBinary {
		log.Errorf("unexpected message type: %d", t)
		return 0, fmt.Errorf("unexpected message type")
	}

	n, err = r.Read(b)
	if err != nil {
		return 0, c.ioErrHandling(err)
	}
	return n, err
}
+
+// Write writes a binary message with the given payload.
+// It does not block until fill the internal buffer.
+// If the buffer filled up, wait until the buffer is drained or timeout.
+func (c *Conn) Write(b []byte) (int, error) {
+ ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout)
+ defer ctxCancel()
+
+ err := c.Conn.Write(ctx, websocket.MessageBinary, b)
+ return len(b), err
+}
+
// LocalAddr returns the local TCP address captured at accept time.
func (c *Conn) LocalAddr() net.Addr {
	return c.lAddr
}

// RemoteAddr returns the remote TCP address captured at accept time.
func (c *Conn) RemoteAddr() net.Addr {
	return c.rAddr
}

// SetReadDeadline is unsupported; reads are governed by the connection ctx.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return fmt.Errorf("SetReadDeadline is not implemented")
}

// SetWriteDeadline is unsupported; writes use the fixed writeTimeout.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return fmt.Errorf("SetWriteDeadline is not implemented")
}

// SetDeadline is unsupported (see SetReadDeadline / SetWriteDeadline).
func (c *Conn) SetDeadline(t time.Time) error {
	return fmt.Errorf("SetDeadline is not implemented")
}
+
// Close marks the connection closed and tears down the underlying websocket
// immediately (no close handshake). The closed flag lets ioErrHandling
// translate subsequent read errors into io.EOF.
func (c *Conn) Close() error {
	c.closedMu.Lock()
	c.closed = true
	c.closedMu.Unlock()
	return c.Conn.CloseNow()
}

// isClosed reports whether Close has been called on this connection.
func (c *Conn) isClosed() bool {
	c.closedMu.Lock()
	defer c.closedMu.Unlock()
	return c.closed
}
+
+func (c *Conn) ioErrHandling(err error) error {
+ if c.isClosed() {
+ return io.EOF
+ }
+
+ var wErr *websocket.CloseError
+ if !errors.As(err, &wErr) {
+ return err
+ }
+ if wErr.Code == websocket.StatusNormalClosure {
+ return io.EOF
+ }
+ return err
+}
diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go
new file mode 100644
index 000000000..10bfbe44d
--- /dev/null
+++ b/relay/server/listener/ws/listener.go
@@ -0,0 +1,92 @@
+package ws
+
+import (
+ "context"
+ "crypto/tls"
+ "errors"
+ "fmt"
+ "net"
+ "net/http"
+
+ log "github.com/sirupsen/logrus"
+ "nhooyr.io/websocket"
+)
+
// URLPath is the path for the websocket connection.
const URLPath = "/relay"

// Listener serves relay websocket connections over HTTP(S).
type Listener struct {
	// Address is the address to listen on.
	Address string
	// TLSConfig is the TLS configuration for the server. When nil the
	// listener serves plain HTTP.
	TLSConfig *tls.Config

	server *http.Server
	// acceptFn is invoked once per upgraded websocket connection.
	acceptFn func(conn net.Conn)
}
+
// Listen starts the HTTP(S) server and blocks until Shutdown is called or
// the server fails. Each websocket upgrade on URLPath calls acceptFn.
// A graceful shutdown (http.ErrServerClosed) is reported as nil.
// NOTE(review): the http.Server has no Read/Write timeouts configured —
// confirm whether slow-client protection is handled elsewhere.
func (l *Listener) Listen(acceptFn func(conn net.Conn)) error {
	l.acceptFn = acceptFn
	mux := http.NewServeMux()
	mux.HandleFunc(URLPath, l.onAccept)

	l.server = &http.Server{
		Addr:      l.Address,
		Handler:   mux,
		TLSConfig: l.TLSConfig,
	}

	log.Infof("WS server listening address: %s", l.Address)
	var err error
	if l.TLSConfig != nil {
		// Cert/key come from TLSConfig, hence the empty file arguments.
		err = l.server.ListenAndServeTLS("", "")
	} else {
		err = l.server.ListenAndServe()
	}
	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return err
}
+
+func (l *Listener) Shutdown(ctx context.Context) error {
+ if l.server == nil {
+ return nil
+ }
+
+ log.Infof("stop WS listener")
+ if err := l.server.Shutdown(ctx); err != nil {
+ return fmt.Errorf("server shutdown failed: %v", err)
+ }
+ log.Infof("WS listener stopped")
+ return nil
+}
+
+func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) {
+ wsConn, err := websocket.Accept(w, r, nil)
+ if err != nil {
+ log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err)
+ return
+ }
+
+ rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
+ if err != nil {
+ err = wsConn.Close(websocket.StatusInternalError, "internal error")
+ if err != nil {
+ log.Errorf("failed to close ws connection: %s", err)
+ }
+ return
+ }
+
+ lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr)
+ if err != nil {
+ err = wsConn.Close(websocket.StatusInternalError, "internal error")
+ if err != nil {
+ log.Errorf("failed to close ws connection: %s", err)
+ }
+ return
+ }
+
+ conn := NewConn(wsConn, lAddr, rAddr)
+ l.acceptFn(conn)
+}
diff --git a/relay/server/peer.go b/relay/server/peer.go
new file mode 100644
index 000000000..a9c542f84
--- /dev/null
+++ b/relay/server/peer.go
@@ -0,0 +1,212 @@
+package server
+
+import (
+ "context"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/relay/healthcheck"
+ "github.com/netbirdio/netbird/relay/messages"
+ "github.com/netbirdio/netbird/relay/metrics"
+)
+
const (
	// bufferSize is the per-peer read buffer; it bounds the largest message
	// a peer can send in one frame.
	bufferSize = 8820
)

// Peer represents a peer connection
type Peer struct {
	metrics *metrics.Metrics
	log     *log.Entry
	idS     string       // human-readable hashed peer ID (used in logs/lookup)
	idB     []byte       // binary hashed peer ID (written into relayed frames)
	conn    net.Conn
	connMu  sync.RWMutex // writers take RLock; Close paths take Lock
	store   *Store
}

// NewPeer creates a new Peer instance and prepare custom logging
func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer {
	stringID := messages.HashIDToString(id)
	return &Peer{
		metrics: metrics,
		log:     log.WithField("peer_id", stringID),
		idS:     stringID,
		idB:     id,
		conn:    conn,
		store:   store,
	}
}
+
// Work reads data from the connection
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly. Returns when the connection errors, an invalid
// message arrives, or the healthcheck machinery closes the connection.
func (p *Peer) Work() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Run the ping/timeout machinery for this peer alongside the read loop.
	hc := healthcheck.NewSender(p.log)
	go hc.StartHealthCheck(ctx)
	go p.handleHealthcheckEvents(ctx, hc)

	buf := make([]byte, bufferSize)
	for {
		n, err := p.conn.Read(buf)
		if err != nil {
			if err != io.EOF {
				p.log.Errorf("failed to read message: %s", err)
			}
			return
		}

		if n == 0 {
			p.log.Errorf("received empty message")
			return
		}

		msg := buf[:n]

		// Byte 0 is the protocol version; unsupported versions end the session.
		_, err = messages.ValidateVersion(msg)
		if err != nil {
			p.log.Warnf("failed to validate protocol version: %s", err)
			return
		}

		// Byte 1 is the message type.
		msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:])
		if err != nil {
			p.log.Errorf("failed to determine message type: %s", err)
			return
		}

		p.handleMsgType(ctx, msgType, hc, n, msg)
	}
}
+
+func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) {
+ switch msgType {
+ case messages.MsgTypeHealthCheck:
+ hc.OnHCResponse()
+ case messages.MsgTypeTransport:
+ p.metrics.TransferBytesRecv.Add(ctx, int64(n))
+ p.metrics.PeerActivity(p.String())
+ p.handleTransportMsg(msg)
+ case messages.MsgTypeClose:
+ p.log.Infof("peer exited gracefully")
+ if err := p.conn.Close(); err != nil {
+ log.Errorf("failed to close connection to peer: %s", err)
+ }
+ default:
+ p.log.Warnf("received unexpected message type: %s", msgType)
+ }
+}
+
// Write writes data to the connection. The read lock permits concurrent
// writers while excluding the Close/CloseGracefully paths, which take the
// write lock.
func (p *Peer) Write(b []byte) (int, error) {
	p.connMu.RLock()
	defer p.connMu.RUnlock()
	return p.conn.Write(b)
}
+
+// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the
+// connection.
+func (p *Peer) CloseGracefully(ctx context.Context) {
+ p.connMu.Lock()
+ defer p.connMu.Unlock()
+ err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg())
+ if err != nil {
+ p.log.Errorf("failed to send close message to peer: %s", p.String())
+ }
+
+ err = p.conn.Close()
+ if err != nil {
+ p.log.Errorf("failed to close connection to peer: %s", err)
+ }
+}
+
+func (p *Peer) Close() {
+ p.connMu.Lock()
+ defer p.connMu.Unlock()
+
+ if err := p.conn.Close(); err != nil {
+ p.log.Errorf("failed to close connection to peer: %s", err)
+ }
+}
+
+// String returns the peer ID
+func (p *Peer) String() string {
+ return p.idS
+}
+
// writeWithTimeout writes buf to the peer connection, giving up after three
// seconds. The write runs in a goroutine because net.Conn.Write has no
// context support here.
// NOTE(review): on timeout the writer goroutine keeps running until the
// underlying Write returns; callers are expected to close the connection
// shortly after — confirm.
func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error {
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	writeDone := make(chan struct{})
	var err error
	go func() {
		_, err = p.conn.Write(buf)
		close(writeDone)
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-writeDone:
		return err
	}
}
+
// handleHealthcheckEvents services the healthcheck sender: it forwards each
// HealthCheck signal as a ping to the peer and closes the connection when
// the sender reports a timeout. Returns when ctx is cancelled or on error.
func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) {
	for {
		select {
		case <-hc.HealthCheck:
			_, err := p.Write(messages.MarshalHealthcheck())
			if err != nil {
				p.log.Errorf("failed to send healthcheck message: %s", err)
				return
			}
		case <-hc.Timeout:
			p.log.Errorf("peer healthcheck timeout")
			// Closing the connection unblocks the Work read loop.
			err := p.conn.Close()
			if err != nil {
				p.log.Errorf("failed to close connection to peer: %s", err)
			}
			p.log.Info("peer connection closed due healthcheck timeout")
			return
		case <-ctx.Done():
			return
		}
	}
}
+
+func (p *Peer) handleTransportMsg(msg []byte) {
+ peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:])
+ if err != nil {
+ p.log.Errorf("failed to unmarshal transport message: %s", err)
+ return
+ }
+
+ stringPeerID := messages.HashIDToString(peerID)
+ dp, ok := p.store.Peer(stringPeerID)
+ if !ok {
+ p.log.Debugf("peer not found: %s", stringPeerID)
+ return
+ }
+
+ err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB)
+ if err != nil {
+ p.log.Errorf("failed to update transport message: %s", err)
+ return
+ }
+
+ n, err := dp.Write(msg)
+ if err != nil {
+ p.log.Errorf("failed to write transport message to: %s", dp.String())
+ return
+ }
+ p.metrics.TransferBytesSent.Add(context.Background(), int64(n))
+}
diff --git a/relay/server/relay.go b/relay/server/relay.go
new file mode 100644
index 000000000..76c01a697
--- /dev/null
+++ b/relay/server/relay.go
@@ -0,0 +1,249 @@
+package server
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/url"
+ "strings"
+ "sync"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+
+ "github.com/netbirdio/netbird/relay/auth"
+ "github.com/netbirdio/netbird/relay/messages"
+ //nolint:staticcheck
+ "github.com/netbirdio/netbird/relay/messages/address"
+ //nolint:staticcheck
+ authmsg "github.com/netbirdio/netbird/relay/messages/auth"
+ "github.com/netbirdio/netbird/relay/metrics"
+)
+
// Relay represents the relay server
type Relay struct {
	metrics       *metrics.Metrics
	metricsCancel context.CancelFunc // stops the metrics collection context
	validator     auth.Validator

	// store holds every peer currently connected to this relay instance.
	store       *Store
	instanceURL string // rel:// or rels:// URL clients use to identify this instance

	// closed marks that Shutdown ran; guarded by closeMu. Accept holds the
	// read lock for the whole handshake, so Shutdown waits for in-flight ones.
	closed  bool
	closeMu sync.RWMutex
}
+
// NewRelay creates a new Relay instance
//
// Parameters:
// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage
// metrics for the relay server.
// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this
// address as the relay server's instance URL.
// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The
// instance URL depends on this value.
// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the
// peers.
//
// Returns:
// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil.
// Otherwise, the error contains the details of what went wrong.
func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) {
	ctx, metricsCancel := context.WithCancel(context.Background())
	m, err := metrics.NewMetrics(ctx, meter)
	if err != nil {
		// Cancel the metrics context on every error path to avoid leaking it.
		metricsCancel()
		return nil, fmt.Errorf("creating app metrics: %v", err)
	}

	r := &Relay{
		metrics:       m,
		metricsCancel: metricsCancel,
		validator:     validator,
		store:         NewStore(),
	}

	r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport)
	if err != nil {
		metricsCancel()
		return nil, fmt.Errorf("get instance URL: %v", err)
	}

	return r, nil
}
+
// getInstanceURL normalizes the exposed address into a relay instance URL.
// When the address carries no scheme, "rels://" (TLS) or "rel://" (plain) is
// prepended; a supplied scheme is kept as-is. The result is parsed, and its
// scheme must be "rel" or "rels".
func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) {
	addr := exposedAddress
	switch strings.Count(exposedAddress, "://") {
	case 0:
		// No scheme supplied by the caller; derive it from the TLS setting.
		scheme := "rel://"
		if tlsSupported {
			scheme = "rels://"
		}
		addr = scheme + exposedAddress
	case 1:
		// A scheme is present; it is validated after parsing below.
	default:
		return "", fmt.Errorf("invalid exposed address: %s", exposedAddress)
	}

	parsedURL, err := url.ParseRequestURI(addr)
	if err != nil {
		return "", fmt.Errorf("invalid exposed address: %v", err)
	}

	if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" {
		return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme)
	}

	return parsedURL.String(), nil
}
+
// Accept start to handle a new peer connection
func (r *Relay) Accept(conn net.Conn) {
	// Hold the read lock for the whole handshake so Shutdown (which takes the
	// write lock) waits for in-flight handshakes before tearing down.
	r.closeMu.RLock()
	defer r.closeMu.RUnlock()
	if r.closed {
		return
	}

	// NOTE(review): handshake reads from conn without a deadline; a client
	// that connects and stays silent blocks this goroutine — confirm a read
	// deadline is applied upstream by the listener.
	peerID, err := r.handshake(conn)
	if err != nil {
		log.Errorf("failed to handshake: %s", err)
		cErr := conn.Close()
		if cErr != nil {
			log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr)
		}
		return
	}

	peer := NewPeer(r.metrics, peerID, conn, r.store)
	peer.log.Infof("peer connected from: %s", conn.RemoteAddr())
	r.store.AddPeer(peer)
	r.metrics.PeerConnected(peer.String())
	go func() {
		// Work blocks until the peer disconnects; then deregister it.
		peer.Work()
		r.store.DeletePeer(peer)
		peer.log.Debugf("relay connection closed")
		r.metrics.PeerDisconnected(peer.String())
	}()
}
+
+// Shutdown closes the relay server
+// It closes the connection with all peers in gracefully and stops accepting new connections.
+func (r *Relay) Shutdown(ctx context.Context) {
+ log.Infof("close connection with all peers")
+ r.closeMu.Lock()
+ wg := sync.WaitGroup{}
+ peers := r.store.Peers()
+ for _, peer := range peers {
+ wg.Add(1)
+ go func(p *Peer) {
+ p.CloseGracefully(ctx)
+ wg.Done()
+ }(peer)
+ }
+ wg.Wait()
+ r.metricsCancel()
+ r.closeMu.Unlock()
+}
+
// InstanceURL returns the instance URL (rel:// or rels://) that clients use
// to identify this relay server.
func (r *Relay) InstanceURL() string {
	return r.instanceURL
}
+
// handshake performs the initial exchange with a freshly accepted connection:
// it reads one message, validates the protocol version, dispatches on the
// message type (deprecated hello vs. current auth), writes the response, and
// returns the authenticated peer's raw ID.
func (r *Relay) handshake(conn net.Conn) ([]byte, error) {
	buf := make([]byte, messages.MaxHandshakeSize)
	n, err := conn.Read(buf)
	if err != nil {
		return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err)
	}

	// The version byte is the first byte of every handshake message.
	_, err = messages.ValidateVersion(buf[:n])
	if err != nil {
		return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err)
	}

	msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n])
	if err != nil {
		return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err)
	}

	var (
		responseMsg []byte
		peerID      []byte
	)
	switch msgType {
	//nolint:staticcheck
	case messages.MsgTypeHello:
		// Legacy clients; kept for backward compatibility.
		peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
	case messages.MsgTypeAuth:
		peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr())
	default:
		return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr())
	}
	if err != nil {
		return nil, err
	}

	_, err = conn.Write(responseMsg)
	if err != nil {
		return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err)
	}

	return peerID, nil
}
+
+func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) {
+ //nolint:staticcheck
+ rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
+ }
+
+ peerID := messages.HashIDToString(rawPeerID)
+ log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr)
+
+ authMsg, err := authmsg.UnmarshalMsg(authData)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unmarshal auth message: %w", err)
+ }
+
+ //nolint:staticcheck
+ if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil {
+ return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err)
+ }
+
+ addr := &address.Address{URL: r.instanceURL}
+ addrData, err := addr.Marshal()
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err)
+ }
+
+ //nolint:staticcheck
+ responseMsg, err := messages.MarshalHelloResponse(addrData)
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err)
+ }
+ return rawPeerID, responseMsg, nil
+}
+
+func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) {
+ rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf)
+ if err != nil {
+ return nil, nil, fmt.Errorf("unmarshal hello message: %w", err)
+ }
+
+ peerID := messages.HashIDToString(rawPeerID)
+
+ if err := r.validator.Validate(authPayload); err != nil {
+ return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err)
+ }
+
+ responseMsg, err := messages.MarshalAuthResponse(r.instanceURL)
+ if err != nil {
+ return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err)
+ }
+
+ return rawPeerID, responseMsg, nil
+}
diff --git a/relay/server/relay_test.go b/relay/server/relay_test.go
new file mode 100644
index 000000000..062039ab9
--- /dev/null
+++ b/relay/server/relay_test.go
@@ -0,0 +1,36 @@
+package server
+
+import "testing"
+
+func TestGetInstanceURL(t *testing.T) {
+ tests := []struct {
+ name string
+ exposedAddress string
+ tlsSupported bool
+ expectedURL string
+ expectError bool
+ }{
+ {"Valid address with TLS", "example.com", true, "rels://example.com", false},
+ {"Valid address without TLS", "example.com", false, "rel://example.com", false},
+ {"Valid address with scheme", "rel://example.com", false, "rel://example.com", false},
+ {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false},
+ {"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false},
+ {"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false},
+ {"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false},
+ {"Invalid address with multiple schemes", "rel://rels://example.com", false, "", true},
+ {"Invalid address with unsupported scheme", "http://example.com", false, "", true},
+ {"Invalid address format", "://example.com", false, "", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ url, err := getInstanceURL(tt.exposedAddress, tt.tlsSupported)
+ if (err != nil) != tt.expectError {
+ t.Errorf("expected error: %v, got: %v", tt.expectError, err)
+ }
+ if url != tt.expectedURL {
+ t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url)
+ }
+ })
+ }
+}
diff --git a/relay/server/server.go b/relay/server/server.go
new file mode 100644
index 000000000..0036e2390
--- /dev/null
+++ b/relay/server/server.go
@@ -0,0 +1,76 @@
+package server
+
+import (
+ "context"
+ "crypto/tls"
+
+ log "github.com/sirupsen/logrus"
+ "go.opentelemetry.io/otel/metric"
+
+ "github.com/netbirdio/netbird/relay/auth"
+ "github.com/netbirdio/netbird/relay/server/listener"
+ "github.com/netbirdio/netbird/relay/server/listener/ws"
+)
+
// ListenerConfig is the configuration for the listener.
// Address: the address to bind the listener to. It could be an address behind a reverse proxy.
// TLSConfig: the TLS configuration for the listener.
type ListenerConfig struct {
	Address   string      // bind address, host:port
	TLSConfig *tls.Config // TLS settings; handling of nil is up to the listener implementation
}
+
// Server is the main entry point for the relay server.
// It is the gate between the WebSocket listener and the Relay server logic.
// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method.
type Server struct {
	relay      *Relay            // core relay logic; receives accepted connections
	wSListener listener.Listener // WebSocket listener; set by Listen
}
+
+// NewServer creates a new relay server instance.
+// meter: the OpenTelemetry meter
+// exposedAddress: this address will be used as the instance URL. It should be a domain:port format.
+// tlsSupport: if true, the server will support TLS
+// authValidator: the auth validator to use for the server
+func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) {
+ relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator)
+ if err != nil {
+ return nil, err
+ }
+ return &Server{
+ relay: relay,
+ }, nil
+}
+
+// Listen starts the relay server.
+func (r *Server) Listen(cfg ListenerConfig) error {
+ r.wSListener = &ws.Listener{
+ Address: cfg.Address,
+ TLSConfig: cfg.TLSConfig,
+ }
+
+ wslErr := r.wSListener.Listen(r.relay.Accept)
+ if wslErr != nil {
+ log.Errorf("failed to bind ws server: %s", wslErr)
+ }
+
+ return wslErr
+}
+
+// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context,
+// the connections will be forcefully closed.
+func (r *Server) Shutdown(ctx context.Context) (err error) {
+ // stop service new connections
+ if r.wSListener != nil {
+ err = r.wSListener.Shutdown(ctx)
+ }
+
+ r.relay.Shutdown(ctx)
+ return
+}
+
+// InstanceURL returns the instance URL of the relay server.
+func (r *Server) InstanceURL() string {
+ return r.relay.instanceURL
+}
diff --git a/relay/server/store.go b/relay/server/store.go
new file mode 100644
index 000000000..4288e62c5
--- /dev/null
+++ b/relay/server/store.go
@@ -0,0 +1,68 @@
+package server
+
+import (
+ "sync"
+)
+
// Store is a thread-safe store of peers
// It is used to store the peers that are connected to the relay server
type Store struct {
	// peers is keyed by the peer's string ID (Peer.String()).
	peers     map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster
	peersLock sync.RWMutex     // guards peers
}
+
+// NewStore creates a new Store instance
+func NewStore() *Store {
+ return &Store{
+ peers: make(map[string]*Peer),
+ }
+}
+
+// AddPeer adds a peer to the store
+func (s *Store) AddPeer(peer *Peer) {
+ s.peersLock.Lock()
+ defer s.peersLock.Unlock()
+ odlPeer, ok := s.peers[peer.String()]
+ if ok {
+ odlPeer.Close()
+ }
+
+ s.peers[peer.String()] = peer
+}
+
+// DeletePeer deletes a peer from the store
+func (s *Store) DeletePeer(peer *Peer) {
+ s.peersLock.Lock()
+ defer s.peersLock.Unlock()
+
+ dp, ok := s.peers[peer.String()]
+ if !ok {
+ return
+ }
+ if dp != peer {
+ return
+ }
+
+ delete(s.peers, peer.String())
+}
+
+// Peer returns a peer by its ID
+func (s *Store) Peer(id string) (*Peer, bool) {
+ s.peersLock.RLock()
+ defer s.peersLock.RUnlock()
+
+ p, ok := s.peers[id]
+ return p, ok
+}
+
+// Peers returns all the peers in the store
+func (s *Store) Peers() []*Peer {
+ s.peersLock.RLock()
+ defer s.peersLock.RUnlock()
+
+ peers := make([]*Peer, 0, len(s.peers))
+ for _, p := range s.peers {
+ peers = append(peers, p)
+ }
+ return peers
+}
diff --git a/relay/server/store_test.go b/relay/server/store_test.go
new file mode 100644
index 000000000..41c7baa92
--- /dev/null
+++ b/relay/server/store_test.go
@@ -0,0 +1,85 @@
+package server
+
+import (
+ "context"
+ "net"
+ "testing"
+ "time"
+
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/metrics"
+)
+
// mockConn is a minimal net.Conn stub for the store tests: only Close is
// implemented (as a no-op); every other method panics because the code under
// test never invokes them.
type mockConn struct {
}

func (m mockConn) Read(b []byte) (n int, err error) {
	//TODO implement me
	panic("implement me")
}

func (m mockConn) Write(b []byte) (n int, err error) {
	//TODO implement me
	panic("implement me")
}

// Close is the only method the store's peer-replacement path calls.
func (m mockConn) Close() error {
	return nil
}

func (m mockConn) LocalAddr() net.Addr {
	//TODO implement me
	panic("implement me")
}

func (m mockConn) RemoteAddr() net.Addr {
	//TODO implement me
	panic("implement me")
}

func (m mockConn) SetDeadline(t time.Time) error {
	//TODO implement me
	panic("implement me")
}

func (m mockConn) SetReadDeadline(t time.Time) error {
	//TODO implement me
	panic("implement me")
}

func (m mockConn) SetWriteDeadline(t time.Time) error {
	//TODO implement me
	panic("implement me")
}
+
+func TestStore_DeletePeer(t *testing.T) {
+ s := NewStore()
+
+ m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
+
+ p := NewPeer(m, []byte("peer_one"), nil, nil)
+ s.AddPeer(p)
+ s.DeletePeer(p)
+ if _, ok := s.Peer(p.String()); ok {
+ t.Errorf("peer was not deleted")
+ }
+}
+
+func TestStore_DeleteDeprecatedPeer(t *testing.T) {
+ s := NewStore()
+
+ m, _ := metrics.NewMetrics(context.Background(), otel.Meter(""))
+
+ conn := &mockConn{}
+ p1 := NewPeer(m, []byte("peer_id"), conn, nil)
+ p2 := NewPeer(m, []byte("peer_id"), conn, nil)
+
+ s.AddPeer(p1)
+ s.AddPeer(p2)
+ s.DeletePeer(p1)
+
+ if _, ok := s.Peer(p2.String()); !ok {
+ t.Errorf("second peer was deleted")
+ }
+}
diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go
new file mode 100644
index 000000000..ec2aa488c
--- /dev/null
+++ b/relay/test/benchmark_test.go
@@ -0,0 +1,386 @@
+package test
+
+import (
+ "context"
+ "crypto/rand"
+ "fmt"
+ "net"
+ "os"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/pion/logging"
+ "github.com/pion/turn/v3"
+ "go.opentelemetry.io/otel"
+
+ "github.com/netbirdio/netbird/relay/auth/allow"
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/client"
+ "github.com/netbirdio/netbird/relay/server"
+ "github.com/netbirdio/netbird/util"
+)
+
var (
	av             = &allow.Auth{}      // validator that accepts every peer
	hmacTokenStore = &hmac.TokenStore{} // empty token store; no credentials required by allow.Auth
	// pairs lists the concurrency levels (peer pairs) each benchmark runs at.
	pairs    = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
	dataSize = 1024 * 1024 * 10 // 10 MB payload per pair
)
+
+func TestMain(m *testing.M) {
+ _ = util.InitLog("error", "console")
+ code := m.Run()
+ os.Exit(code)
+}
+
// TestRelayDataTransfer measures relay throughput for each pair count in
// pairs. It is a benchmark disguised as a test and is skipped unconditionally
// (remove the SkipNow call to run it locally).
func TestRelayDataTransfer(t *testing.T) {
	t.SkipNow() // skip this test on CI because it is a benchmark test
	testData, err := seedRandomData(dataSize)
	if err != nil {
		t.Fatalf("failed to seed random data: %s", err)
	}

	for _, peerPairs := range pairs {
		t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
			transfer(t, testData, peerPairs)
		})
	}
}
+
// TestTurnDataTransfer run turn server:
// docker run --rm --name coturn -d --network=host coturn/coturn --user test:test
//
// Benchmarks TURN throughput for comparison with the relay numbers; skipped
// unconditionally and requires the external coturn container above.
func TestTurnDataTransfer(t *testing.T) {
	t.SkipNow() // skip this test on CI because it is a benchmark test
	testData, err := seedRandomData(dataSize)
	if err != nil {
		t.Fatalf("failed to seed random data: %s", err)
	}

	for _, peerPairs := range pairs {
		t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) {
			runTurnTest(t, testData, peerPairs)
		})
	}
}
+
+func transfer(t *testing.T, testData []byte, peerPairs int) {
+ t.Helper()
+ ctx := context.Background()
+ port := 35000 + peerPairs
+ serverAddress := fmt.Sprintf("127.0.0.1:%d", port)
+ serverConnURL := fmt.Sprintf("rel://%s", serverAddress)
+
+ srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av)
+ if err != nil {
+ t.Fatalf("failed to create server: %s", err)
+ }
+ errChan := make(chan error, 1)
+ go func() {
+ listenCfg := server.ListenerConfig{Address: serverAddress}
+ err := srv.Listen(listenCfg)
+ if err != nil {
+ errChan <- err
+ }
+ }()
+
+ defer func() {
+ err := srv.Shutdown(ctx)
+ if err != nil {
+ t.Errorf("failed to close server: %s", err)
+ }
+ }()
+
+ // wait for server to start
+ if err := waitForServerToStart(errChan); err != nil {
+ t.Fatalf("failed to start server: %s", err)
+ }
+
+ clientsSender := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsSender); i++ {
+ c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
+ err := c.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsSender[i] = c
+ }
+
+ clientsReceiver := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsReceiver); i++ {
+ c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
+ err := c.Connect()
+ if err != nil {
+ t.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsReceiver[i] = c
+ }
+
+ connsSender := make([]net.Conn, 0, peerPairs)
+ connsReceiver := make([]net.Conn, 0, peerPairs)
+ for i := 0; i < len(clientsSender); i++ {
+ conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+ connsSender = append(connsSender, conn)
+
+ conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
+ if err != nil {
+ t.Fatalf("failed to bind channel: %s", err)
+ }
+ connsReceiver = append(connsReceiver, conn)
+ }
+
+ var transferDuration []time.Duration
+ wg := sync.WaitGroup{}
+ var writeErr error
+ var readErr error
+ for i := 0; i < len(connsSender); i++ {
+ wg.Add(2)
+ start := time.Now()
+ go func(i int) {
+ defer wg.Done()
+ pieceSize := 1024
+ testDataLen := len(testData)
+
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, writeErr = connsSender[i].Write(testData[j:end])
+ if writeErr != nil {
+ return
+ }
+ }
+
+ }(i)
+
+ go func(i int, start time.Time) {
+ defer wg.Done()
+ buf := make([]byte, 8192)
+ rcv := 0
+ var n int
+ for receivedSize := 0; receivedSize < len(testData); {
+
+ n, readErr = connsReceiver[i].Read(buf)
+ if readErr != nil {
+ return
+ }
+
+ receivedSize += n
+ rcv += n
+ }
+ transferDuration = append(transferDuration, time.Since(start))
+ }(i, start)
+ }
+
+ wg.Wait()
+
+ if writeErr != nil {
+ t.Fatalf("failed to write to channel: %s", err)
+ }
+
+ if readErr != nil {
+ t.Fatalf("failed to read from channel: %s", err)
+ }
+
+ // calculate the megabytes per second from the average transferDuration against the dataSize
+ var totalDuration time.Duration
+ for _, d := range transferDuration {
+ totalDuration += d
+ }
+ avgDuration := totalDuration / time.Duration(len(transferDuration))
+ mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
+ t.Logf("average transfer duration: %s", avgDuration)
+ t.Logf("average transfer speed: %.2f MB/s", mbps)
+
+ for i := 0; i < len(connsSender); i++ {
+ err := connsSender[i].Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+
+ err = connsReceiver[i].Close()
+ if err != nil {
+ t.Errorf("failed to close connection: %s", err)
+ }
+ }
+}
+
+func runTurnTest(t *testing.T, testData []byte, maxPairs int) {
+ t.Helper()
+ var transferDuration []time.Duration
+ var wg sync.WaitGroup
+
+ for i := 0; i < maxPairs; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ d := runTurnDataTransfer(t, testData)
+ transferDuration = append(transferDuration, d)
+ }()
+
+ }
+ wg.Wait()
+
+ var totalDuration time.Duration
+ for _, d := range transferDuration {
+ totalDuration += d
+ }
+ avgDuration := totalDuration / time.Duration(len(transferDuration))
+ mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024
+ t.Logf("average transfer duration: %s", avgDuration)
+ t.Logf("average transfer speed: %.2f MB/s", mbps)
+}
+
// runTurnDataTransfer allocates a TURN relay, sends testData from the TURN
// allocation to a local UDP receiver in 1024-byte pieces with a simple
// ack-based flow control (the first 10 pieces are pipelined, after that one
// ack is awaited per piece), and returns how long the receive side took.
//
// NOTE(review): tb, start and timerInit are written by the reader goroutine
// and read by this goroutine; the readDone channel orders the final reads,
// but time.Since(start) on the timeout path reads an unsynchronized start —
// acceptable for a benchmark, not for production code.
func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration {
	t.Helper()
	testDataLen := len(testData)
	relayAddress := "192.168.0.10:3478" // hard-coded lab TURN server (see TestTurnDataTransfer docs)
	conn, err := net.Dial("tcp", relayAddress)
	if err != nil {
		t.Fatal(err)
	}
	defer func(conn net.Conn) {
		_ = conn.Close()
	}(conn)

	turnClient, err := getTurnClient(t, relayAddress, conn)
	if err != nil {
		t.Fatal(err)
	}
	defer turnClient.Close()

	relayConn, err := turnClient.Allocate()
	if err != nil {
		t.Fatal(err)
	}
	defer func(relayConn net.PacketConn) {
		_ = relayConn.Close()
	}(relayConn)

	// Local UDP socket standing in for the remote peer.
	receiverConn, err := net.Dial("udp", relayConn.LocalAddr().String())
	if err != nil {
		t.Fatal(err)
	}
	defer func(receiverConn net.Conn) {
		_ = receiverConn.Close()
	}(receiverConn)

	var (
		tb        int       // total bytes received
		start     time.Time // set on the first received datagram
		timerInit bool
		readDone  = make(chan struct{})
		ack       = make([]byte, 1)
	)
	go func() {
		defer func() {
			readDone <- struct{}{}
		}()
		buff := make([]byte, 8192)
		for {
			n, e := receiverConn.Read(buff)
			if e != nil {
				return
			}
			if !timerInit {
				start = time.Now()
				timerInit = true
			}
			tb += n
			// One ack per datagram drives the sender's flow control.
			_, _ = receiverConn.Write(ack)

			if tb >= testDataLen {
				return
			}
		}
	}()

	pieceSize := 1024
	ackBuff := make([]byte, 1)
	pipelineSize := 10 // allow this many unacked pieces in flight
	for j := 0; j < testDataLen; j += pieceSize {
		end := j + pieceSize
		if end > testDataLen {
			end = testDataLen
		}
		_, err := relayConn.WriteTo(testData[j:end], receiverConn.LocalAddr())
		if err != nil {
			t.Fatalf("failed to write to channel: %s", err)
		}
		if pipelineSize == 0 {
			_, _, _ = relayConn.ReadFrom(ackBuff)
		} else {
			pipelineSize--
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
	defer cancel()
	select {
	case <-readDone:
		if tb != testDataLen {
			t.Fatalf("failed to read all data: %d/%d", tb, testDataLen)
		}
	case <-ctx.Done():
		t.Fatal("timeout")
	}
	return time.Since(start)
}
+
+func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) {
+ t.Helper()
+ // Dial TURN Server
+ addrStr := fmt.Sprintf("%s:%d", address, 443)
+
+ fac := logging.NewDefaultLoggerFactory()
+ //fac.DefaultLogLevel = logging.LogLevelTrace
+
+ // Start a new TURN Client and wrap our net.Conn in a STUNConn
+ // This allows us to simulate datagram based communication over a net.Conn
+ cfg := &turn.ClientConfig{
+ TURNServerAddr: address,
+ Conn: turn.NewSTUNConn(conn),
+ Username: "test",
+ Password: "test",
+ LoggerFactory: fac,
+ }
+
+ client, err := turn.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
+ }
+
+ // Start listening on the conn provided.
+ err = client.Listen()
+ if err != nil {
+ client.Close()
+ return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
+ }
+
+ return client, nil
+}
+
// seedRandomData returns size bytes of cryptographically random test data.
func seedRandomData(size int) ([]byte, error) {
	data := make([]byte, size)
	if _, err := rand.Read(data); err != nil {
		return nil, err
	}
	return data, nil
}
+
+func waitForServerToStart(errChan chan error) error {
+ select {
+ case err := <-errChan:
+ if err != nil {
+ return err
+ }
+ case <-time.After(300 * time.Millisecond):
+ return nil
+ }
+ return nil
+}
diff --git a/relay/testec2/main.go b/relay/testec2/main.go
new file mode 100644
index 000000000..0c8099a5e
--- /dev/null
+++ b/relay/testec2/main.go
@@ -0,0 +1,258 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "crypto/rand"
+ "flag"
+ "fmt"
+ "net"
+ "os"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/util"
+)
+
const (
	// errMsgFailedReadTCP is the shared log format for TCP read failures;
	// presumably used by the TURN helpers in this package (not visible here).
	errMsgFailedReadTCP = "failed to read from tcp: %s"
)
+
var (
	dataSize            = 1024 * 1024 * 50 // 50MB
	pairs               = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} // concurrency levels per run
	signalListenAddress = ":8081" // where the receiver-side signal server listens

	// The following are populated in main() from TEST_* environment variables.
	relaySrvAddress string
	turnSrvAddress  string
	signalURL       string
	udpListener     string // used for TURN test
)
+
// testResult captures the outcome of one benchmark run.
type testResult struct {
	numOfPairs int           // number of concurrent peer pairs in the run
	duration   time.Duration // average transfer duration
	speed      float64       // average transfer speed, bytes per second
}

// Speed renders the speed as a human-readable string with a binary unit
// (B/s, KB/s, MB/s or GB/s) and two decimals.
func (tr testResult) Speed() string {
	const (
		kb = 1024.0
		mb = 1024.0 * kb
		gb = 1024.0 * mb
	)

	value := tr.speed
	unit := "B/s"
	switch {
	case value >= gb:
		value /= gb
		unit = "GB/s"
	case value >= mb:
		value /= mb
		unit = "MB/s"
	case value >= kb:
		value /= kb
		unit = "KB/s"
	}

	return fmt.Sprintf("%.2f %s", value, unit)
}
+
// seedRandomData returns size bytes of cryptographically random payload.
func seedRandomData(size int) ([]byte, error) {
	payload := make([]byte, size)
	if _, err := rand.Read(payload); err != nil {
		return nil, err
	}
	return payload, nil
}
+
+func avg(transferDuration []time.Duration) (time.Duration, float64) {
+ var totalDuration time.Duration
+ for _, d := range transferDuration {
+ totalDuration += d
+ }
+ avgDuration := totalDuration / time.Duration(len(transferDuration))
+ bps := float64(dataSize) / avgDuration.Seconds()
+ return avgDuration, bps
+}
+
// RelayReceiverMain runs the receiver side of the relay benchmark for every
// pair count in pairs, collecting the per-run average duration and speed.
// It reads the relay address from the package-level relaySrvAddress.
func RelayReceiverMain() []testResult {
	testResults := make([]testResult, 0, len(pairs))
	for _, p := range pairs {
		tr := testResult{numOfPairs: p}
		td := relayReceive(relaySrvAddress, p)
		tr.duration, tr.speed = avg(td)

		testResults = append(testResults, tr)
	}

	return testResults
}
+
// RelaySenderMain runs the sender side of the relay benchmark: it seeds the
// random payload once, then streams it for every pair count in pairs,
// pausing between runs so the receiver can set up fresh connections.
func RelaySenderMain() {
	log.Infof("starting sender")
	log.Infof("starting seed phase")

	testData, err := seedRandomData(dataSize)
	if err != nil {
		log.Fatalf("failed to seed random data: %s", err)
	}

	log.Infof("data size: %d", len(testData))

	for n, p := range pairs {
		log.Infof("running test with %d pairs", p)
		relayTransfer(relaySrvAddress, testData, p)

		// grant time to prepare new receivers
		if n < len(pairs)-1 {
			time.Sleep(3 * time.Second)
		}
	}

}
+
// TRUNSenderMain is the sender
// - allocate turn clients
// - send relayed addresses to signal server in batch
// - wait for signal server to send back addresses in a map
// - send test data to each address in parallel
//
// NOTE(review): the name has a typo (TRUN -> TURN); callers in this file use
// the misspelled name, so renaming must be coordinated across the file.
func TRUNSenderMain() {
	log.Infof("starting TURN sender test")

	log.Infof("starting seed random data: %d", dataSize)
	testData, err := seedRandomData(dataSize)
	if err != nil {
		log.Fatalf("failed to seed random data: %s", err)
	}

	ss := SignalClient{signalURL}

	for _, p := range pairs {
		log.Infof("running test with %d pairs", p)
		turnSender := &TurnSender{}

		createTurnConns(p, turnSender)

		log.Infof("send addresses via signal server: %d", len(turnSender.addresses))
		clientAddresses, err := ss.SendAddress(turnSender.addresses)
		if err != nil {
			log.Fatalf("failed to send address: %s", err)
		}
		log.Infof("received addresses: %v", clientAddresses.Address)

		createSenderDevices(turnSender, clientAddresses)

		log.Infof("waiting for tcpListeners to be ready")
		time.Sleep(2 * time.Second)

		// Dial one TCP connection per device; the 10.0.x.2:9999 scheme is
		// presumably established by createSenderDevices — verify there.
		tcpConns := make([]net.Conn, 0, len(turnSender.devices))
		for i := range turnSender.devices {
			addr := fmt.Sprintf("10.0.%d.2:9999", i)
			log.Infof("dialing: %s", addr)
			tcpConn, err := net.Dial("tcp", addr)
			if err != nil {
				log.Fatalf("failed to dial tcp: %s", err)
			}
			tcpConns = append(tcpConns, tcpConn)
		}

		log.Infof("start test data transfer for %d pairs", p)
		testDataLen := len(testData)
		wg := sync.WaitGroup{}
		wg.Add(len(tcpConns))
		for i, tcpConn := range tcpConns {
			log.Infof("sending test data to device: %d", i)
			go runTurnWriting(tcpConn, testData, testDataLen, &wg)
		}
		wg.Wait()

		for _, d := range turnSender.devices {
			_ = d.Close()
		}

		log.Infof("test finished with %d pairs", p)
	}
}
+
// TURNReaderMain runs the receiver side of the TURN benchmark. It starts the
// signal server, and for every expected run: receives the sender's relayed
// addresses, creates local reader devices, reports the client addresses back,
// reads the payload on each device, and aggregates the per-run averages.
func TURNReaderMain() []testResult {
	log.Infof("starting TURN receiver test")
	si := NewSignalService()
	go func() {
		log.Infof("starting signal server")
		err := si.Listen(signalListenAddress)
		if err != nil {
			log.Errorf("failed to listen: %s", err)
		}
	}()

	testResults := make([]testResult, 0, len(pairs))
	for range pairs {
		addresses := <-si.AddressesChan
		instanceNumber := len(addresses)
		log.Infof("received addresses: %d", instanceNumber)

		turnReceiver := &TurnReceiver{}
		err := createDevices(addresses, turnReceiver)
		if err != nil {
			log.Fatalf("%s", err)
		}

		// send client addresses back via signal server
		si.ClientAddressChan <- turnReceiver.clientAddresses

		durations := make(chan time.Duration, instanceNumber)
		for _, device := range turnReceiver.devices {
			go runTurnReading(device, durations)
		}

		// Drain exactly one duration per device; close unblocks the range.
		durationsList := make([]time.Duration, 0, instanceNumber)
		for d := range durations {
			durationsList = append(durationsList, d)
			if len(durationsList) == instanceNumber {
				close(durations)
			}
		}

		avgDuration, avgSpeed := avg(durationsList)
		ts := testResult{
			numOfPairs: len(durationsList),
			duration:   avgDuration,
			speed:      avgSpeed,
		}
		testResults = append(testResults, ts)

		for _, d := range turnReceiver.devices {
			_ = d.Close()
		}
	}
	return testResults
}
+
// main runs either the receiver or the sender half of the EC2 benchmark,
// selected by the -mode flag. Endpoints come from TEST_* environment
// variables.
//
// NOTE(review): the summary loop indexes relayResult with turnResults'
// indices; both derive from the same pairs slice here, but a skipped relay
// run would make the lengths diverge and panic — verify if runs can fail.
func main() {
	var mode string

	_ = util.InitLog("debug", "console")
	flag.StringVar(&mode, "mode", "sender", "sender or receiver mode")
	flag.Parse()

	relaySrvAddress = os.Getenv("TEST_RELAY_SERVER") // rel://ip:port
	turnSrvAddress = os.Getenv("TEST_TURN_SERVER")   // ip:3478
	signalURL = os.Getenv("TEST_SIGNAL_URL")         // http://receiver_ip:8081
	udpListener = os.Getenv("TEST_UDP_LISTENER")     // IP:0

	if mode == "receiver" {
		relayResult := RelayReceiverMain()
		turnResults := TURNReaderMain()
		for i := 0; i < len(turnResults); i++ {
			log.Infof("pairs: %d,\tRelay speed:\t%s,\trelay duration:\t%s", relayResult[i].numOfPairs, relayResult[i].Speed(), relayResult[i].duration)
			log.Infof("pairs: %d,\tTURN speed:\t%s,\tturn duration:\t%s", turnResults[i].numOfPairs, turnResults[i].Speed(), turnResults[i].duration)
		}
	} else {
		RelaySenderMain()
		// grant time for receiver to start
		time.Sleep(3 * time.Second)
		TRUNSenderMain()
	}
}
diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go
new file mode 100644
index 000000000..93d084387
--- /dev/null
+++ b/relay/testec2/relay.go
@@ -0,0 +1,176 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+
+ "github.com/netbirdio/netbird/relay/auth/hmac"
+ "github.com/netbirdio/netbird/relay/client"
+)
+
+var (
+ hmacTokenStore = &hmac.TokenStore{}
+)
+
+func relayTransfer(serverConnURL string, testData []byte, peerPairs int) {
+ connsSender := prepareConnsSender(serverConnURL, peerPairs)
+ defer func() {
+ for i := 0; i < len(connsSender); i++ {
+ err := connsSender[i].Close()
+ if err != nil {
+ log.Errorf("failed to close connection: %s", err)
+ }
+ }
+ }()
+
+ wg := sync.WaitGroup{}
+ wg.Add(len(connsSender))
+ for _, conn := range connsSender {
+ go func(conn net.Conn) {
+ defer wg.Done()
+ runWriter(conn, testData)
+ }(conn)
+ }
+ wg.Wait()
+}
+
+func runWriter(conn net.Conn, testData []byte) {
+ si := NewStartInidication(time.Now(), len(testData))
+ _, err := conn.Write(si)
+ if err != nil {
+ log.Errorf("failed to write to channel: %s", err)
+ return
+ }
+ log.Infof("sent start indication")
+
+ pieceSize := 1024
+ testDataLen := len(testData)
+
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, writeErr := conn.Write(testData[j:end])
+ if writeErr != nil {
+ log.Errorf("failed to write to channel: %s", writeErr)
+ return
+ }
+ }
+}
+
+func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn {
+ ctx := context.Background()
+ clientsSender := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsSender); i++ {
+ c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i))
+ if err := c.Connect(); err != nil {
+ log.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsSender[i] = c
+ }
+
+ connsSender := make([]net.Conn, 0, peerPairs)
+ for i := 0; i < len(clientsSender); i++ {
+ conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i))
+ if err != nil {
+ log.Fatalf("failed to bind channel: %s", err)
+ }
+ connsSender = append(connsSender, conn)
+ }
+ return connsSender
+}
+
+func relayReceive(serverConnURL string, peerPairs int) []time.Duration {
+ connsReceiver := prepareConnsReceiver(serverConnURL, peerPairs)
+ defer func() {
+ for i := 0; i < len(connsReceiver); i++ {
+ if err := connsReceiver[i].Close(); err != nil {
+ log.Errorf("failed to close connection: %s", err)
+ }
+ }
+ }()
+
+ durations := make(chan time.Duration, len(connsReceiver))
+ wg := sync.WaitGroup{}
+ for _, conn := range connsReceiver {
+ wg.Add(1)
+ go func(conn net.Conn) {
+ defer wg.Done()
+ duration := runReader(conn)
+ durations <- duration
+ }(conn)
+ }
+ wg.Wait()
+
+ durationsList := make([]time.Duration, 0, len(connsReceiver))
+ for d := range durations {
+ durationsList = append(durationsList, d)
+ if len(durationsList) == len(connsReceiver) {
+ close(durations)
+ }
+ }
+
+ return durationsList
+}
+
+func runReader(conn net.Conn) time.Duration {
+ buf := make([]byte, 8192)
+
+ n, readErr := conn.Read(buf)
+ if readErr != nil {
+ log.Errorf("failed to read from channel: %s", readErr)
+ return 0
+ }
+
+ si := DecodeStartIndication(buf[:n])
+ log.Infof("received start indication: %v", si)
+
+ receivedSize, err := conn.Read(buf)
+ if err != nil {
+ log.Fatalf("failed to read from relay: %s", err)
+ }
+ now := time.Now()
+
+ rcv := 0
+ for receivedSize < si.TransferSize {
+ n, readErr = conn.Read(buf)
+ if readErr != nil {
+ log.Errorf("failed to read from channel: %s", readErr)
+ return 0
+ }
+
+ receivedSize += n
+ rcv += n
+ }
+ return time.Since(now)
+}
+
+func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn {
+ clientsReceiver := make([]*client.Client, peerPairs)
+ for i := 0; i < cap(clientsReceiver); i++ {
+ c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i))
+ err := c.Connect()
+ if err != nil {
+ log.Fatalf("failed to connect to server: %s", err)
+ }
+ clientsReceiver[i] = c
+ }
+
+ connsReceiver := make([]net.Conn, 0, peerPairs)
+ for i := 0; i < len(clientsReceiver); i++ {
+ conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i))
+ if err != nil {
+ log.Fatalf("failed to bind channel: %s", err)
+ }
+ connsReceiver = append(connsReceiver, conn)
+ }
+ return connsReceiver
+}
diff --git a/relay/testec2/signal.go b/relay/testec2/signal.go
new file mode 100644
index 000000000..fe93a2fe2
--- /dev/null
+++ b/relay/testec2/signal.go
@@ -0,0 +1,91 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type PeerAddr struct {
+ Address []string
+}
+
+type ClientPeerAddr struct {
+ Address map[string]string
+}
+
+type Signal struct {
+ AddressesChan chan []string
+ ClientAddressChan chan map[string]string
+}
+
+func NewSignalService() *Signal {
+ return &Signal{
+ AddressesChan: make(chan []string),
+ ClientAddressChan: make(chan map[string]string),
+ }
+}
+
+func (rs *Signal) Listen(listenAddr string) error {
+ http.HandleFunc("/", rs.onNewAddresses)
+ return http.ListenAndServe(listenAddr, nil)
+}
+
+func (rs *Signal) onNewAddresses(w http.ResponseWriter, r *http.Request) {
+ var msg PeerAddr
+ err := json.NewDecoder(r.Body).Decode(&msg)
+ if err != nil {
+ log.Errorf("Error decoding message: %v", err)
+ }
+
+ log.Infof("received addresses: %d", len(msg.Address))
+ rs.AddressesChan <- msg.Address
+ clientAddresses := <-rs.ClientAddressChan
+
+ respMsg := ClientPeerAddr{
+ Address: clientAddresses,
+ }
+ data, err := json.Marshal(respMsg)
+ if err != nil {
+ log.Errorf("Error marshalling message: %v", err)
+ return
+ }
+
+ _, err = w.Write(data)
+ if err != nil {
+ log.Errorf("Error writing response: %v", err)
+ }
+}
+
+type SignalClient struct {
+ SignalURL string
+}
+
+func (ss SignalClient) SendAddress(addresses []string) (*ClientPeerAddr, error) {
+ msg := PeerAddr{
+ Address: addresses,
+ }
+ data, err := json.Marshal(msg)
+ if err != nil {
+ return nil, err
+ }
+
+ response, err := http.Post(ss.SignalURL, "application/json", bytes.NewBuffer(data))
+ if err != nil {
+ return nil, err
+ }
+
+ defer response.Body.Close()
+
+ log.Debugf("wait for signal response")
+ var respPeerAddress ClientPeerAddr
+ err = json.NewDecoder(response.Body).Decode(&respPeerAddress)
+ if err != nil {
+ return nil, err
+ }
+ return &respPeerAddress, nil
+}
diff --git a/relay/testec2/start_msg.go b/relay/testec2/start_msg.go
new file mode 100644
index 000000000..19b65380b
--- /dev/null
+++ b/relay/testec2/start_msg.go
@@ -0,0 +1,39 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "bytes"
+ "encoding/gob"
+ "time"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type StartIndication struct {
+ Started time.Time
+ TransferSize int
+}
+
+func NewStartInidication(started time.Time, transferSize int) []byte {
+ si := StartIndication{
+ Started: started,
+ TransferSize: transferSize,
+ }
+
+ var data bytes.Buffer
+ err := gob.NewEncoder(&data).Encode(si)
+ if err != nil {
+ log.Fatal("encode error:", err)
+ }
+ return data.Bytes()
+}
+
+func DecodeStartIndication(data []byte) StartIndication {
+ var si StartIndication
+ err := gob.NewDecoder(bytes.NewReader(data)).Decode(&si)
+ if err != nil {
+ log.Fatal("decode error:", err)
+ }
+ return si
+}
diff --git a/relay/testec2/tun/proxy.go b/relay/testec2/tun/proxy.go
new file mode 100644
index 000000000..7d84bece7
--- /dev/null
+++ b/relay/testec2/tun/proxy.go
@@ -0,0 +1,72 @@
+//go:build linux || darwin
+
+package tun
+
+import (
+ "net"
+ "sync/atomic"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type Proxy struct {
+ Device *Device
+ PConn net.PacketConn
+ DstAddr net.Addr
+ shutdownFlag atomic.Bool
+}
+
+func (p *Proxy) Start() {
+ go p.readFromDevice()
+ go p.readFromConn()
+}
+
+func (p *Proxy) Close() {
+ p.shutdownFlag.Store(true)
+}
+
+func (p *Proxy) readFromDevice() {
+ buf := make([]byte, 1500)
+ for {
+ n, err := p.Device.Read(buf)
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to read from device: %s", err)
+ return
+ }
+
+ _, err = p.PConn.WriteTo(buf[:n], p.DstAddr)
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to write to conn: %s", err)
+ return
+ }
+ }
+}
+
+func (p *Proxy) readFromConn() {
+ buf := make([]byte, 1500)
+ for {
+ n, _, err := p.PConn.ReadFrom(buf)
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to read from conn: %s", err)
+ return
+ }
+
+ _, err = p.Device.Write(buf[:n])
+ if err != nil {
+ if p.shutdownFlag.Load() {
+ return
+ }
+ log.Errorf("failed to write to device: %s", err)
+ return
+ }
+ }
+}
diff --git a/relay/testec2/tun/tun.go b/relay/testec2/tun/tun.go
new file mode 100644
index 000000000..5580785ce
--- /dev/null
+++ b/relay/testec2/tun/tun.go
@@ -0,0 +1,110 @@
+//go:build linux || darwin
+
+package tun
+
+import (
+ "net"
+
+ log "github.com/sirupsen/logrus"
+ "github.com/songgao/water"
+ "github.com/vishvananda/netlink"
+)
+
+type Device struct {
+ Name string
+ IP string
+ PConn net.PacketConn
+ DstAddr net.Addr
+
+ iFace *water.Interface
+ proxy *Proxy
+}
+
+func (d *Device) Up() error {
+ cfg := water.Config{
+ DeviceType: water.TUN,
+ PlatformSpecificParams: water.PlatformSpecificParams{
+ Name: d.Name,
+ },
+ }
+ iFace, err := water.New(cfg)
+ if err != nil {
+ return err
+ }
+ d.iFace = iFace
+
+ err = d.assignIP()
+ if err != nil {
+ return err
+ }
+
+ err = d.bringUp()
+ if err != nil {
+ return err
+ }
+
+ d.proxy = &Proxy{
+ Device: d,
+ PConn: d.PConn,
+ DstAddr: d.DstAddr,
+ }
+ d.proxy.Start()
+ return nil
+}
+
+func (d *Device) Close() error {
+ if d.proxy != nil {
+ d.proxy.Close()
+ }
+ if d.iFace != nil {
+ return d.iFace.Close()
+ }
+ return nil
+}
+
+func (d *Device) Read(b []byte) (int, error) {
+ return d.iFace.Read(b)
+}
+
+func (d *Device) Write(b []byte) (int, error) {
+ return d.iFace.Write(b)
+}
+
+func (d *Device) assignIP() error {
+ iface, err := netlink.LinkByName(d.Name)
+ if err != nil {
+ log.Errorf("failed to get TUN device: %v", err)
+ return err
+ }
+
+ ip := net.IPNet{
+ IP: net.ParseIP(d.IP),
+ Mask: net.CIDRMask(24, 32),
+ }
+
+ addr := &netlink.Addr{
+ IPNet: &ip,
+ }
+ err = netlink.AddrAdd(iface, addr)
+ if err != nil {
+ log.Errorf("failed to add IP address: %v", err)
+ return err
+ }
+ return nil
+}
+
+func (d *Device) bringUp() error {
+ iface, err := netlink.LinkByName(d.Name)
+ if err != nil {
+ log.Errorf("failed to get device: %v", err)
+ return err
+ }
+
+ // Bring the interface up
+ err = netlink.LinkSetUp(iface)
+ if err != nil {
+ log.Errorf("failed to set device up: %v", err)
+ return err
+ }
+ return nil
+}
diff --git a/relay/testec2/turn.go b/relay/testec2/turn.go
new file mode 100644
index 000000000..8beb40423
--- /dev/null
+++ b/relay/testec2/turn.go
@@ -0,0 +1,181 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "fmt"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/netbirdio/netbird/relay/testec2/tun"
+
+ log "github.com/sirupsen/logrus"
+)
+
+type TurnReceiver struct {
+ conns []*net.UDPConn
+ clientAddresses map[string]string
+ devices []*tun.Device
+}
+
+type TurnSender struct {
+ turnConns map[string]*TurnConn
+ addresses []string
+ devices []*tun.Device
+}
+
+func runTurnWriting(tcpConn net.Conn, testData []byte, testDataLen int, wg *sync.WaitGroup) {
+ defer wg.Done()
+ defer tcpConn.Close()
+
+ log.Infof("start to sending test data: %s", tcpConn.RemoteAddr())
+
+ si := NewStartInidication(time.Now(), testDataLen)
+ _, err := tcpConn.Write(si)
+ if err != nil {
+ log.Errorf("failed to write to tcp: %s", err)
+ return
+ }
+
+ pieceSize := 1024
+ for j := 0; j < testDataLen; j += pieceSize {
+ end := j + pieceSize
+ if end > testDataLen {
+ end = testDataLen
+ }
+ _, writeErr := tcpConn.Write(testData[j:end])
+ if writeErr != nil {
+ log.Errorf("failed to write to tcp conn: %s", writeErr)
+ return
+ }
+ }
+
+ // grant time to flush out packages
+ time.Sleep(3 * time.Second)
+}
+
+func createSenderDevices(sender *TurnSender, clientAddresses *ClientPeerAddr) {
+ var i int
+ devices := make([]*tun.Device, 0, len(clientAddresses.Address))
+ for k, v := range clientAddresses.Address {
+ tc, ok := sender.turnConns[k]
+ if !ok {
+ log.Fatalf("failed to find turn conn: %s", k)
+ }
+
+ addr, err := net.ResolveUDPAddr("udp", v)
+ if err != nil {
+ log.Fatalf("failed to resolve udp address: %s", err)
+ }
+ device := &tun.Device{
+ Name: fmt.Sprintf("mtun-sender-%d", i),
+ IP: fmt.Sprintf("10.0.%d.1", i),
+ PConn: tc.relayConn,
+ DstAddr: addr,
+ }
+
+ err = device.Up()
+ if err != nil {
+ log.Fatalf("failed to bring up device: %s", err)
+ }
+
+ devices = append(devices, device)
+ i++
+ }
+ sender.devices = devices
+}
+
+func createTurnConns(p int, sender *TurnSender) {
+ turnConns := make(map[string]*TurnConn)
+ addresses := make([]string, 0, len(pairs))
+ for i := 0; i < p; i++ {
+ tc := AllocateTurnClient(turnSrvAddress)
+ log.Infof("allocated turn client: %s", tc.Address().String())
+ turnConns[tc.Address().String()] = tc
+ addresses = append(addresses, tc.Address().String())
+ }
+
+ sender.turnConns = turnConns
+ sender.addresses = addresses
+}
+
+func runTurnReading(d *tun.Device, durations chan time.Duration) {
+ tcpListener, err := net.Listen("tcp", d.IP+":9999")
+ if err != nil {
+ log.Fatalf("failed to listen on tcp: %s", err)
+ }
+ log := log.WithField("device", tcpListener.Addr())
+
+ tcpConn, err := tcpListener.Accept()
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf("failed to accept connection: %s", err)
+ }
+ log.Infof("remote peer connected")
+
+ buf := make([]byte, 103)
+ n, err := tcpConn.Read(buf)
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf(errMsgFailedReadTCP, err)
+ }
+
+ si := DecodeStartIndication(buf[:n])
+ log.Infof("received start indication: %v, %d", si, n)
+
+ buf = make([]byte, 8192)
+ i, err := tcpConn.Read(buf)
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf(errMsgFailedReadTCP, err)
+ }
+ now := time.Now()
+ for i < si.TransferSize {
+ n, err := tcpConn.Read(buf)
+ if err != nil {
+ _ = tcpListener.Close()
+ log.Fatalf(errMsgFailedReadTCP, err)
+ }
+ i += n
+ }
+ durations <- time.Since(now)
+}
+
+func createDevices(addresses []string, receiver *TurnReceiver) error {
+ receiver.conns = make([]*net.UDPConn, 0, len(addresses))
+ receiver.clientAddresses = make(map[string]string, len(addresses))
+ receiver.devices = make([]*tun.Device, 0, len(addresses))
+ for i, addr := range addresses {
+ localAddr, err := net.ResolveUDPAddr("udp", udpListener)
+ if err != nil {
+ return fmt.Errorf("failed to resolve UDP address: %s", err)
+ }
+
+ conn, err := net.ListenUDP("udp", localAddr)
+ if err != nil {
+ return fmt.Errorf("failed to create UDP connection: %s", err)
+ }
+
+ receiver.conns = append(receiver.conns, conn)
+ receiver.clientAddresses[addr] = conn.LocalAddr().String()
+
+ dstAddr, err := net.ResolveUDPAddr("udp", addr)
+ if err != nil {
+ return fmt.Errorf("failed to resolve address: %s", err)
+ }
+
+ device := &tun.Device{
+ Name: fmt.Sprintf("mtun-%d", i),
+ IP: fmt.Sprintf("10.0.%d.2", i),
+ PConn: conn,
+ DstAddr: dstAddr,
+ }
+
+ if err = device.Up(); err != nil {
+ return fmt.Errorf("failed to bring up device: %s, %s", device.Name, err)
+ }
+ receiver.devices = append(receiver.devices, device)
+ }
+ return nil
+}
diff --git a/relay/testec2/turn_allocator.go b/relay/testec2/turn_allocator.go
new file mode 100644
index 000000000..fd86208df
--- /dev/null
+++ b/relay/testec2/turn_allocator.go
@@ -0,0 +1,83 @@
+//go:build linux || darwin
+
+package main
+
+import (
+ "fmt"
+ "net"
+
+ "github.com/pion/logging"
+ "github.com/pion/turn/v3"
+ log "github.com/sirupsen/logrus"
+)
+
+type TurnConn struct {
+ conn net.Conn
+ turnClient *turn.Client
+ relayConn net.PacketConn
+}
+
+func (tc *TurnConn) Address() net.Addr {
+ return tc.relayConn.LocalAddr()
+}
+
+func (tc *TurnConn) Close() {
+ _ = tc.relayConn.Close()
+ tc.turnClient.Close()
+ _ = tc.conn.Close()
+}
+
+func AllocateTurnClient(serverAddr string) *TurnConn {
+ conn, err := net.Dial("tcp", serverAddr)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ turnClient, err := getTurnClient(serverAddr, conn)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ relayConn, err := turnClient.Allocate()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ return &TurnConn{
+ conn: conn,
+ turnClient: turnClient,
+ relayConn: relayConn,
+ }
+}
+
+func getTurnClient(address string, conn net.Conn) (*turn.Client, error) {
+ // Dial TURN Server
+ addrStr := fmt.Sprintf("%s:%d", address, 443)
+
+ fac := logging.NewDefaultLoggerFactory()
+ //fac.DefaultLogLevel = logging.LogLevelTrace
+
+ // Start a new TURN Client and wrap our net.Conn in a STUNConn
+ // This allows us to simulate datagram based communication over a net.Conn
+ cfg := &turn.ClientConfig{
+ TURNServerAddr: address,
+ Conn: turn.NewSTUNConn(conn),
+ Username: "test",
+ Password: "test",
+ LoggerFactory: fac,
+ }
+
+ client, err := turn.NewClient(cfg)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err)
+ }
+
+ // Start listening on the conn provided.
+ err = client.Listen()
+ if err != nil {
+ client.Close()
+ return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err)
+ }
+
+ return client, nil
+}
diff --git a/release_files/install.sh b/release_files/install.sh
index 198d74428..b7a6c08f9 100755
--- a/release_files/install.sh
+++ b/release_files/install.sh
@@ -21,6 +21,8 @@ SUDO=""
if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then
SUDO="sudo"
+elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then
+ SUDO="doas"
fi
if [ -z ${NETBIRD_RELEASE+x} ]; then
@@ -31,14 +33,16 @@ get_release() {
local RELEASE=$1
if [ "$RELEASE" = "latest" ]; then
local TAG="latest"
+ local URL="https://pkgs.netbird.io/releases/latest"
else
local TAG="tags/${RELEASE}"
+ local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}"
fi
if [ -n "$GITHUB_TOKEN" ]; then
- curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \
+ curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
else
- curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \
+ curl -s "${URL}" \
| grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/'
fi
}
@@ -68,7 +72,7 @@ download_release_binary() {
if [ -n "$GITHUB_TOKEN" ]; then
cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL"
else
- cd /tmp && curl -LO "$DOWNLOAD_URL"
+ cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL"
fi
@@ -151,6 +155,22 @@ add_aur_repo() {
${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm
}
+prepare_tun_module() {
+ # Create the necessary file structure for /dev/net/tun
+ if [ ! -c /dev/net/tun ]; then
+ if [ ! -d /dev/net ]; then
+ mkdir -m 755 /dev/net
+ fi
+ mknod /dev/net/tun c 10 200
+ chmod 0755 /dev/net/tun
+ fi
+
+ # Load the tun module if not already loaded
+ if ! lsmod | grep -q "^tun\s"; then
+ insmod /lib/modules/tun.ko
+ fi
+}
+
install_native_binaries() {
# Checks for supported architecture
case "$ARCH" in
@@ -226,6 +246,13 @@ install_netbird() {
${SUDO} dnf -y install netbird-ui
fi
;;
+ rpm-ostree)
+ add_rpm_repo
+ ${SUDO} rpm-ostree -y install netbird
+ if ! $SKIP_UI_APP; then
+ ${SUDO} rpm-ostree -y install netbird-ui
+ fi
+ ;;
pacman)
${SUDO} pacman -Syy
add_aur_repo
@@ -268,16 +295,22 @@ install_netbird() {
;;
esac
+ if [ "$OS_NAME" = "synology" ]; then
+ prepare_tun_module
+ fi
+
# Add package manager to config
${SUDO} mkdir -p "$CONFIG_FOLDER"
echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null
# Load and start netbird service
- if ! ${SUDO} netbird service install 2>&1; then
- echo "NetBird service has already been loaded"
- fi
- if ! ${SUDO} netbird service start 2>&1; then
- echo "NetBird service has already been started"
+ if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then
+ if ! ${SUDO} netbird service install 2>&1; then
+ echo "NetBird service has already been loaded"
+ fi
+ if ! ${SUDO} netbird service start 2>&1; then
+ echo "NetBird service has already been started"
+ fi
fi
@@ -287,7 +320,7 @@ install_netbird() {
}
version_greater_equal() {
- printf '%s\n%s\n' "$2" "$1" | sort -V -C
+ printf '%s\n%s\n' "$2" "$1" | sort -V -c
}
is_bin_package_manager() {
@@ -383,6 +416,9 @@ if type uname >/dev/null 2>&1; then
elif [ -x "$(command -v dnf)" ]; then
PACKAGE_MANAGER="dnf"
echo "The installation will be performed using dnf package manager"
+ elif [ -x "$(command -v rpm-ostree)" ]; then
+ PACKAGE_MANAGER="rpm-ostree"
+ echo "The installation will be performed using rpm-ostree package manager"
elif [ -x "$(command -v yum)" ]; then
PACKAGE_MANAGER="yum"
echo "The installation will be performed using yum package manager"
diff --git a/route/route.go b/route/route.go
index eb6c36bd8..e23801e6e 100644
--- a/route/route.go
+++ b/route/route.go
@@ -100,6 +100,7 @@ type Route struct {
Metric int
Enabled bool
Groups []string `gorm:"serializer:json"`
+ AccessControlGroups []string `gorm:"serializer:json"`
}
// EventMeta returns activity event meta related to the route
@@ -123,6 +124,7 @@ func (r *Route) Copy() *Route {
Masquerade: r.Masquerade,
Enabled: r.Enabled,
Groups: slices.Clone(r.Groups),
+ AccessControlGroups: slices.Clone(r.AccessControlGroups),
}
return route
}
@@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool {
other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled &&
slices.Equal(r.Groups, other.Groups) &&
- slices.Equal(r.PeerGroups, other.PeerGroups)
+ slices.Equal(r.PeerGroups, other.PeerGroups) &&
+ slices.Equal(r.AccessControlGroups, other.AccessControlGroups)
}
// IsDynamic returns if the route is dynamic, i.e. has domains
diff --git a/signal/client/client.go b/signal/client/client.go
index 9d99b3677..ced3fb7d0 100644
--- a/signal/client/client.go
+++ b/signal/client/client.go
@@ -51,11 +51,10 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) {
}
// MarshalCredential marshal a Credential instance and returns a Message object
-func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type,
- rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) {
+func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) {
return &proto.Message{
Key: myKey.PublicKey().String(),
- RemoteKey: remoteKey.String(),
+ RemoteKey: remoteKey,
Body: &proto.Body{
Type: t,
Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd),
@@ -65,6 +64,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre
RosenpassPubKey: rosenpassPubKey,
RosenpassServerAddr: rosenpassAddr,
},
+ RelayServerAddress: relaySrvAddress,
},
}, nil
}
diff --git a/signal/client/client_test.go b/signal/client/client_test.go
index 2525493b4..f7d4ebc50 100644
--- a/signal/client/client_test.go
+++ b/signal/client/client_test.go
@@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) {
panic(err)
}
s := grpc.NewServer()
- srv, err := server.NewServer(otel.Meter(""))
+ srv, err := server.NewServer(context.Background(), otel.Meter(""))
if err != nil {
panic(err)
}
diff --git a/signal/cmd/run.go b/signal/cmd/run.go
index 61f7a32a7..1bb2f1d0c 100644
--- a/signal/cmd/run.go
+++ b/signal/cmd/run.go
@@ -29,12 +29,9 @@ import (
"google.golang.org/grpc/keepalive"
)
-const (
- metricsPort = 9090
-)
-
var (
signalPort int
+ metricsPort int
signalLetsencryptDomain string
signalSSLDir string
defaultSignalSSLDir string
@@ -105,7 +102,7 @@ var (
}
}()
- srv, err := server.NewServer(metricsServer.Meter)
+ srv, err := server.NewServer(cmd.Context(), metricsServer.Meter)
if err != nil {
return fmt.Errorf("creating signal server: %v", err)
}
@@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) {
func init() {
runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise")
+ runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics")
runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.")
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
diff --git a/signal/peer/peer.go b/signal/peer/peer.go
index 3149526b2..ed2360d67 100644
--- a/signal/peer/peer.go
+++ b/signal/peer/peer.go
@@ -18,16 +18,20 @@ type Peer struct {
StreamID int64
- //a gRpc connection stream to the Peer
+ // a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
+
+ // registration time
+ RegisteredAt time.Time
}
// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer {
return &Peer{
- Id: id,
- Stream: stream,
- StreamID: time.Now().UnixNano(),
+ Id: id,
+ Stream: stream,
+ StreamID: time.Now().UnixNano(),
+ RegisteredAt: time.Now(),
}
}
@@ -78,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) {
log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.",
peer.Id, peer.StreamID, pp.StreamID)
registry.Peers.Store(peer.Id, peer)
+ return
}
+
log.Debugf("peer registered [%s]", peer.Id)
+ registry.metrics.ActivePeers.Add(context.Background(), 1)
// record time as milliseconds
registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6)
@@ -101,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) {
peer.Id, pp.StreamID, peer.StreamID)
return
}
+ registry.metrics.ActivePeers.Add(context.Background(), -1)
+ log.Debugf("peer deregistered [%s]", peer.Id)
+ registry.metrics.Deregistrations.Add(context.Background(), 1)
}
- log.Debugf("peer deregistered [%s]", peer.Id)
-
- registry.metrics.Deregistrations.Add(context.Background(), 1)
}
diff --git a/signal/proto/signalexchange.pb.go b/signal/proto/signalexchange.pb.go
index 782c45da1..30f704c6f 100644
--- a/signal/proto/signalexchange.pb.go
+++ b/signal/proto/signalexchange.pb.go
@@ -1,15 +1,15 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
-// protoc v3.12.4
+// protoc v3.21.12
// source: signalexchange.proto
package proto
import (
- _ "github.com/golang/protobuf/protoc-gen-go/descriptor"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
+ _ "google.golang.org/protobuf/types/descriptorpb"
reflect "reflect"
sync "sync"
)
@@ -225,6 +225,8 @@ type Body struct {
FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"`
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"`
+ // relayServerAddress is an IP:port of the relay server
+ RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"`
}
func (x *Body) Reset() {
@@ -308,6 +310,13 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig {
return nil
}
+func (x *Body) GetRelayServerAddress() string {
+ if x != nil {
+ return x.RelayServerAddress
+ }
+ return ""
+}
+
// Mode indicates a connection mode
type Mode struct {
state protoimpl.MessageState
@@ -431,7 +440,7 @@ var file_signalexchange_proto_rawDesc = []byte{
0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62,
0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e,
0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52,
- 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
+ 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa6, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d,
0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73,
0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f,
0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a,
@@ -451,7 +460,10 @@ var file_signalexchange_proto_rawDesc = []byte{
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63,
0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43,
0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73,
- 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
+ 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53,
+ 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01,
+ 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41,
+ 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09,
0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53,
0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41,
0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e,
diff --git a/signal/proto/signalexchange.proto b/signal/proto/signalexchange.proto
index a8c4c309c..4431edd7c 100644
--- a/signal/proto/signalexchange.proto
+++ b/signal/proto/signalexchange.proto
@@ -60,6 +60,9 @@ message Body {
// RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to
RosenpassConfig rosenpassConfig = 7;
+
+ // relayServerAddress is url of the relay server
+ string relayServerAddress = 8;
}
// Mode indicates a connection mode
diff --git a/signal/server/signal.go b/signal/server/signal.go
index 219bdcc41..63cc43bd7 100644
--- a/signal/server/signal.go
+++ b/signal/server/signal.go
@@ -13,6 +13,8 @@ import (
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
+ "github.com/netbirdio/signal-dispatcher/dispatcher"
+
"github.com/netbirdio/netbird/signal/metrics"
"github.com/netbirdio/netbird/signal/peer"
"github.com/netbirdio/netbird/signal/proto"
@@ -40,20 +42,26 @@ const (
type Server struct {
registry *peer.Registry
proto.UnimplementedSignalExchangeServer
-
- metrics *metrics.AppMetrics
+ dispatcher *dispatcher.Dispatcher
+ metrics *metrics.AppMetrics
}
// NewServer creates a new Signal server
-func NewServer(meter metric.Meter) (*Server, error) {
+func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) {
appMetrics, err := metrics.NewAppMetrics(meter)
if err != nil {
return nil, fmt.Errorf("creating app metrics: %v", err)
}
+ dispatcher, err := dispatcher.NewDispatcher(ctx, meter)
+ if err != nil {
+ return nil, fmt.Errorf("creating dispatcher: %v", err)
+ }
+
s := &Server{
- registry: peer.NewRegistry(appMetrics),
- metrics: appMetrics,
+ dispatcher: dispatcher,
+ registry: peer.NewRegistry(appMetrics),
+ metrics: appMetrics,
}
return s, nil
@@ -61,57 +69,26 @@ func NewServer(meter metric.Meter) (*Server, error) {
// Send forwards a message to the signal peer
func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
- if !s.registry.IsPeerRegistered(msg.Key) {
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotRegistered)))
+ log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
- return nil, fmt.Errorf("peer %s is not registered", msg.Key)
+ if _, found := s.registry.Get(msg.RemoteKey); found {
+ s.forwardMessageToPeer(ctx, msg)
+ return &proto.EncryptedMessage{}, nil
}
- getRegistrationStart := time.Now()
-
- if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
- s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
- start := time.Now()
- //forward the message to the target peer
- if err := dstPeer.Stream.Send(msg); err != nil {
- log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
- //todo respond to the sender?
-
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
- } else {
- s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage)))
- s.metrics.MessagesForwarded.Add(context.Background(), 1)
- }
- } else {
- s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
- log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
- //todo respond to the sender?
-
- s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
- }
- return &proto.EncryptedMessage{}, nil
+ return s.dispatcher.SendMessage(context.Background(), msg)
}
// ConnectStream connects to the exchange stream
func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error {
- p, err := s.connectPeer(stream)
+ p, err := s.RegisterPeer(stream)
if err != nil {
return err
}
- startRegister := time.Now()
+ defer s.DeregisterPeer(p)
- s.metrics.ActivePeers.Add(stream.Context(), 1)
-
- defer func() {
- log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
- s.registry.Deregister(p)
-
- s.metrics.PeerConnectionDuration.Record(stream.Context(), int64(time.Since(startRegister).Seconds()))
- s.metrics.ActivePeers.Add(context.Background(), -1)
- }()
-
- //needed to confirm that the peer has been registered so that the client can proceed
+ // needed to confirm that the peer has been registered so that the client can proceed
header := metadata.Pairs(proto.HeaderRegistered, "1")
err = stream.SendHeader(header)
if err != nil {
@@ -119,11 +96,10 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
return err
}
- log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
+ log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID)
for {
-
- //read incoming messages
+ // read incoming messages
msg, err := stream.Recv()
if err == io.EOF {
break
@@ -131,44 +107,26 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer)
return err
}
- log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey)
+ log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
- getRegistrationStart := time.Now()
-
- // lookup the target peer where the message is going to
- if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
- s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
- start := time.Now()
- //forward the message to the target peer
- if err := dstPeer.Stream.Send(msg); err != nil {
- log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err)
- //todo respond to the sender?
- s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
- } else {
- // in milliseconds
- s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
- s.metrics.MessagesForwarded.Add(stream.Context(), 1)
- }
- } else {
- s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
- s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
- log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey)
- //todo respond to the sender?
+ _, err = s.dispatcher.SendMessage(stream.Context(), msg)
+ if err != nil {
+ log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
}
}
+
<-stream.Context().Done()
return stream.Context().Err()
}
-// Handles initial Peer connection.
-// Each connection must provide an Id header.
-// At this moment the connecting Peer will be registered in the peer.Registry
-func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) {
+ log.Debugf("registering new peer")
if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta {
if id, found := meta[proto.HeaderId]; found {
p := peer.NewPeer(id[0], stream)
s.registry.Register(p)
+ s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer)
return p, nil
} else {
@@ -180,3 +138,37 @@ func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*p
return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta")
}
}
+
+func (s *Server) DeregisterPeer(p *peer.Peer) {
+ log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID)
+ s.registry.Deregister(p)
+
+ s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds()))
+}
+
+func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) {
+ log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey)
+
+ getRegistrationStart := time.Now()
+
+ // lookup the target peer where the message is going to
+ if dstPeer, found := s.registry.Get(msg.RemoteKey); found {
+ s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound)))
+ start := time.Now()
+ // forward the message to the target peer
+ if err := dstPeer.Stream.Send(msg); err != nil {
+ log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err)
+ // todo respond to the sender?
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError)))
+ } else {
+ // in milliseconds
+ s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream)))
+ s.metrics.MessagesForwarded.Add(ctx, 1)
+ }
+ } else {
+ s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound)))
+ s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected)))
+ log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey)
+ // todo respond to the sender?
+ }
+}
diff --git a/util/file.go b/util/file.go
index 2a6182556..8355488c9 100644
--- a/util/file.go
+++ b/util/file.go
@@ -10,51 +10,30 @@ import (
log "github.com/sirupsen/logrus"
)
-// WriteJson writes JSON config object to a file creating parent directories if required
-// The output JSON is pretty-formatted
-func WriteJson(file string, obj interface{}) error {
-
+// WriteJsonWithRestrictedPermission writes a JSON config object to a file and enforces restricted permissions on the parent directory
+func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
- // make it pretty
- bs, err := json.MarshalIndent(obj, "", " ")
+ err = EnforcePermission(file)
if err != nil {
return err
}
- tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
+ return writeJson(file, obj, configDir, configFileName)
+}
+
+// WriteJson writes JSON config object to a file creating parent directories if required
+// The output JSON is pretty-formatted
+func WriteJson(file string, obj interface{}) error {
+ configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
- tempFileName := tempFile.Name()
- // closing file ops as windows doesn't allow to move it
- err = tempFile.Close()
- if err != nil {
- return err
- }
-
- defer func() {
- _, err = os.Stat(tempFileName)
- if err == nil {
- os.Remove(tempFileName)
- }
- }()
-
- err = os.WriteFile(tempFileName, bs, 0600)
- if err != nil {
- return err
- }
-
- err = os.Rename(tempFileName, file)
- if err != nil {
- return err
- }
-
- return nil
+ return writeJson(file, obj, configDir, configFileName)
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
+func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
+
+ // make it pretty
+ bs, err := json.MarshalIndent(obj, "", " ")
+ if err != nil {
+ return err
+ }
+
+ tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
+ if err != nil {
+ return err
+ }
+
+ tempFileName := tempFile.Name()
+	// close the file before renaming it, as Windows doesn't allow moving an open file
+ err = tempFile.Close()
+ if err != nil {
+ return err
+ }
+
+ defer func() {
+ _, err = os.Stat(tempFileName)
+ if err == nil {
+ os.Remove(tempFileName)
+ }
+ }()
+
+ err = os.WriteFile(tempFileName, bs, 0600)
+ if err != nil {
+ return err
+ }
+
+ err = os.Rename(tempFileName, file)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
func openOrCreateFile(file string) (*os.File, error) {
s, err := os.Stat(file)
if err == nil {
@@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) {
}
err := os.MkdirAll(configDir, 0750)
+ if err != nil {
+ return "", "", err
+ }
+
return configDir, configFileName, err
}
diff --git a/util/log.go b/util/log.go
index 4bce75e4a..7a9235ee6 100644
--- a/util/log.go
+++ b/util/log.go
@@ -5,6 +5,7 @@ import (
"os"
"path/filepath"
"slices"
+ "strconv"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
@@ -12,6 +13,8 @@ import (
"github.com/netbirdio/netbird/formatter"
)
+const defaultLogSize = 5
+
// InitLog parses and sets log-level input
func InitLog(logLevel string, logPath string) error {
level, err := log.ParseLevel(logLevel)
@@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error {
log.Errorf("Failed parsing log-level %s: %s", logLevel, err)
return err
}
- customOutputs := []string{"console", "syslog"};
+ customOutputs := []string{"console", "syslog"}
if logPath != "" && !slices.Contains(customOutputs, logPath) {
+ maxLogSize := getLogMaxSize()
lumberjackLogger := &lumberjack.Logger{
// Log file absolute path, os agnostic
Filename: filepath.ToSlash(logPath),
- MaxSize: 5, // MB
+ MaxSize: maxLogSize, // MB
MaxBackups: 10,
MaxAge: 30, // days
Compress: true,
@@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error {
log.SetLevel(level)
return nil
}
+
+func getLogMaxSize() int {
+ if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok {
+ size, err := strconv.ParseInt(sizeVar, 10, 64)
+ if err != nil {
+ log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err)
+ return defaultLogSize
+ }
+
+ log.Infof("Setting log file max size to %d MB", size)
+
+ return int(size)
+ }
+ return defaultLogSize
+}
diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go
index 7a5de7587..4032a75c0 100644
--- a/util/net/dialer_nonios.go
+++ b/util/net/dialer_nonios.go
@@ -49,6 +49,8 @@ func RemoveDialerHooks() {
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+ log.Debugf("Dialing %s %s", network, address)
+
if CustomRoutingDisabled() {
return d.Dialer.DialContext(ctx, network, address)
}
diff --git a/util/net/net.go b/util/net/net.go
index 8d1fcebd0..61b47dbe7 100644
--- a/util/net/net.go
+++ b/util/net/net.go
@@ -4,7 +4,7 @@ import (
"net"
"os"
- "github.com/netbirdio/netbird/iface/netstack"
+ "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/google/uuid"
)
diff --git a/util/permission.go b/util/permission.go
new file mode 100644
index 000000000..666998cff
--- /dev/null
+++ b/util/permission.go
@@ -0,0 +1,7 @@
+//go:build !windows
+
+package util
+
+func EnforcePermission(dirPath string) error {
+ return nil
+}
diff --git a/util/permission_windows.go b/util/permission_windows.go
new file mode 100644
index 000000000..548fef824
--- /dev/null
+++ b/util/permission_windows.go
@@ -0,0 +1,86 @@
+package util
+
+import (
+ "path/filepath"
+
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/sys/windows"
+)
+
+const (
+ securityFlags = windows.OWNER_SECURITY_INFORMATION |
+ windows.GROUP_SECURITY_INFORMATION |
+ windows.DACL_SECURITY_INFORMATION |
+ windows.PROTECTED_DACL_SECURITY_INFORMATION
+)
+
+func EnforcePermission(file string) error {
+ dirPath := filepath.Dir(file)
+
+ user, group, err := sids()
+ if err != nil {
+ return err
+ }
+
+ adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid)
+ if err != nil {
+ return err
+ }
+
+ explicitAccess := []windows.EXPLICIT_ACCESS{
+ {
+ AccessPermissions: windows.GENERIC_ALL,
+ AccessMode: windows.SET_ACCESS,
+ Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
+ Trustee: windows.TRUSTEE{
+ MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
+ TrusteeForm: windows.TRUSTEE_IS_SID,
+ TrusteeType: windows.TRUSTEE_IS_USER,
+ TrusteeValue: windows.TrusteeValueFromSID(user),
+ },
+ },
+ {
+ AccessPermissions: windows.GENERIC_ALL,
+ AccessMode: windows.SET_ACCESS,
+ Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT,
+ Trustee: windows.TRUSTEE{
+ MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE,
+ TrusteeForm: windows.TRUSTEE_IS_SID,
+ TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP,
+ TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid),
+ },
+ },
+ }
+
+ dacl, err := windows.ACLFromEntries(explicitAccess, nil)
+ if err != nil {
+ return err
+ }
+
+ return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil)
+}
+
+func sids() (*windows.SID, *windows.SID, error) {
+ var token windows.Token
+ err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer func() {
+ if err := token.Close(); err != nil {
+ log.Errorf("failed to close process token: %v", err)
+ }
+ }()
+
+ tu, err := token.GetTokenUser()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ pg, err := token.GetTokenPrimaryGroup()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return tu.User.Sid, pg.PrimaryGroup, nil
+}