diff --git a/.github/ISSUE_TEMPLATE/bug-issue-report.md b/.github/ISSUE_TEMPLATE/bug-issue-report.md index 789c61974..87f757f42 100644 --- a/.github/ISSUE_TEMPLATE/bug-issue-report.md +++ b/.github/ISSUE_TEMPLATE/bug-issue-report.md @@ -35,7 +35,7 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan If applicable, add the `netbird status -dA' command output. -**Do you face any client issues on desktop?** +**Do you face any (non-mobile) client issues?** Please provide the file created by `netbird debug for 1m -AS`. We advise reviewing the anonymized files for any remaining PII. diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 2b4c43cb4..2aaef7564 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -18,14 +18,14 @@ jobs: runs-on: macos-latest steps: - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: macos-go-${{ hashFiles('**/go.sum') }} diff --git a/.github/workflows/golang-test-freebsd.yml b/.github/workflows/golang-test-freebsd.yml index 4f13ee30e..a2d743715 100644 --- a/.github/workflows/golang-test-freebsd.yml +++ b/.github/workflows/golang-test-freebsd.yml @@ -38,7 +38,7 @@ jobs: time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./formatter/... - time go test -timeout 1m -failfast ./iface/... + time go test -timeout 1m -failfast ./client/iface/... time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./signal/... diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 120b213e9..524f35f6f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -19,13 +19,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -33,7 +33,7 @@ jobs: ${{ runner.os }}-go- - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -49,18 +49,18 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 steps: - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -68,7 +68,7 @@ jobs: ${{ runner.os }}-go- - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -80,7 +80,7 @@ jobs: run: git --no-pager diff --exit-code - name: Generate Iface Test bin - run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/ + run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/ - name: Generate Shared Sock Test bin run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock @@ -124,4 +124,4 @@ jobs: run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1 - name: Run Peer tests in docker - run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 \ No newline at end of file + run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1 diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index 2d63acbcd..d378bec3f 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -17,13 +17,13 @@ jobs: runs-on: windows-latest steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 id: go with: - go-version: "1.21.x" + go-version: "1.23.x" - name: Download wintun uses: carlosperate/download-file-action@v2 diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 78b9f504f..2d743f790 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -15,11 +15,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable, + ignore_words_list: erro,clienta,hastable,iif skip: go.mod,go.sum only_warn: 1 golangci: @@ -32,15 +32,15 @@ jobs: timeout-minutes: 15 steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Check for duplicate constants if: matrix.os == 'ubuntu-latest' run: | ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep . - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" cache: false - name: Install dependencies if: matrix.os == 'ubuntu-latest' @@ -49,4 +49,4 @@ jobs: uses: golangci/golangci-lint-action@v3 with: version: latest - args: --timeout=12m \ No newline at end of file + args: --timeout=12m diff --git a/.github/workflows/install-script-test.yml b/.github/workflows/install-script-test.yml index dfb8a279b..22d002a48 100644 --- a/.github/workflows/install-script-test.yml +++ b/.github/workflows/install-script-test.yml @@ -13,6 +13,7 @@ concurrency: jobs: test-install-script: strategy: + fail-fast: false max-parallel: 2 matrix: os: [ubuntu-latest, macos-latest] @@ -21,7 +22,7 @@ jobs: runs-on: ${{ matrix.os }} steps: - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: run install script env: diff --git a/.github/workflows/mobile-build-validation.yml b/.github/workflows/mobile-build-validation.yml index e5a5ff485..dcf461a34 100644 --- a/.github/workflows/mobile-build-validation.yml +++ b/.github/workflows/mobile-build-validation.yml @@ -15,23 +15,23 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" - name: Setup Android SDK uses: android-actions/setup-android@v3 with: cmdline-tools-version: 8512546 - name: Setup Java - uses: actions/setup-java@v3 + uses: actions/setup-java@v4 with: java-version: "11" distribution: "adopt" - name: NDK Cache id: ndk-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: /usr/local/lib/android/sdk/ndk key: ndk-cache-23.1.7779620 @@ -50,11 +50,11 @@ jobs: runs-on: macos-latest steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" - name: install gomobile run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed - name: gomobile init @@ -62,4 +62,4 @@ jobs: - name: build iOS netbird lib run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK env: - CGO_ENABLED: 0 \ No newline at end of file + CGO_ENABLED: 0 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 30f24e92e..7af6d3e4d 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,15 +3,14 @@ name: Release on: push: tags: - - 'v*' + - "v*" branches: - main pull_request: - env: - SIGN_PIPE_VER: "v0.0.12" - GORELEASER_VER: "v1.14.1" + SIGN_PIPE_VER: "v0.0.14" + GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" @@ -34,20 +33,17 @@ jobs: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout - uses: actions/checkout@v3 + - name: Checkout + uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go - uses: actions/setup-go@v4 + - name: Set up Go + uses: actions/setup-go@v5 with: - go-version: "1.21" + go-version: "1.23" cache: false - - - name: Cache Go modules - uses: actions/cache@v3 + - name: Cache Go modules + uses: actions/cache@v4 with: path: | ~/go/pkg/mod @@ -55,24 +51,19 @@ jobs: key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go-releaser- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Set up QEMU + - name: Set up QEMU uses: docker/setup-qemu-action@v2 - - - name: Set up Docker Buildx + - name: Set up Docker Buildx uses: docker/setup-buildx-action@v2 - - - name: Login to Docker hub + - name: Login to Docker hub if: github.event_name != 'pull_request' uses: docker/login-action@v1 with: - username: netbirdio + username: ${{ secrets.DOCKER_USER }} password: ${{ secrets.DOCKER_TOKEN }} - name: Install OS build dependencies run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu @@ -85,36 +76,32 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --rm-dist ${{ env.flags }} + args: release --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - - - name: upload non tags for debug purposes - uses: actions/upload-artifact@v3 + - name: upload non tags for debug purposes + uses: actions/upload-artifact@v4 with: name: release path: dist/ retention-days: 3 - - - name: upload linux packages - uses: actions/upload-artifact@v3 + - name: upload linux packages + uses: actions/upload-artifact@v4 with: name: linux-packages path: dist/netbird_linux** retention-days: 3 - - - name: upload windows packages - uses: actions/upload-artifact@v3 + - name: upload windows packages + uses: actions/upload-artifact@v4 with: name: windows-packages path: dist/netbird_windows** retention-days: 3 - - - name: upload macos packages - uses: actions/upload-artifact@v3 + - name: upload macos packages + uses: actions/upload-artifact@v4 with: name: macos-packages path: dist/netbird_darwin** @@ -133,19 +120,19 @@ jobs: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21" + go-version: "1.23" cache: false - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/cache@v4 with: - path: | + path: | ~/go/pkg/mod ~/.cache/go-build key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }} @@ -169,14 +156,14 @@ jobs: uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} - name: upload non tags for debug purposes - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: release-ui path: dist/ @@ -187,20 +174,17 @@ jobs: steps: - if: ${{ !startsWith(github.ref, 'refs/tags/v') }} run: echo "flags=--snapshot" >> $GITHUB_ENV - - - name: Checkout - uses: actions/checkout@v3 + - name: Checkout + uses: actions/checkout@v4 with: fetch-depth: 0 # It is required for GoReleaser to work properly - - - name: Set up Go - uses: actions/setup-go@v4 + - name: Set up Go + uses: actions/setup-go@v5 with: - go-version: "1.21" + go-version: "1.23" cache: false - - - name: Cache Go modules - uses: actions/cache@v3 + - name: Cache Go modules + uses: actions/cache@v4 with: path: | ~/go/pkg/mod @@ -208,52 +192,34 @@ jobs: key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-ui-go-releaser-darwin- - - - name: Install modules + - name: Install modules run: go mod tidy - - - name: check git status + - name: check git status run: git --no-pager diff --exit-code - - - name: Run GoReleaser + - name: Run GoReleaser id: goreleaser uses: goreleaser/goreleaser-action@v4 with: version: ${{ env.GORELEASER_VER }} - args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }} + args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: upload non tags for debug purposes - uses: actions/upload-artifact@v3 + - name: upload non tags for debug purposes + uses: actions/upload-artifact@v4 with: name: release-ui-darwin path: dist/ retention-days: 3 - trigger_windows_signer: + trigger_signer: runs-on: ubuntu-latest - needs: [release,release_ui] + needs: [release, release_ui, release_ui_darwin] if: startsWith(github.ref, 'refs/tags/') steps: - - name: Trigger Windows binaries sign pipeline + - name: Trigger binaries sign pipelines uses: benc-uk/workflow-dispatch@v1 with: - workflow: Sign windows bin and installer - repo: netbirdio/sign-pipelines - ref: ${{ env.SIGN_PIPE_VER }} - token: ${{ secrets.SIGN_GITHUB_TOKEN }} - inputs: '{ "tag": "${{ github.ref }}" }' - - trigger_darwin_signer: - runs-on: ubuntu-latest - needs: [release,release_ui_darwin] - if: startsWith(github.ref, 'refs/tags/') - steps: - - name: Trigger Darwin App binaries sign pipeline - uses: benc-uk/workflow-dispatch@v1 - with: - workflow: Sign darwin ui app with dispatch + workflow: Sign bin and installer repo: netbirdio/sign-pipelines ref: ${{ env.SIGN_PIPE_VER }} token: ${{ secrets.SIGN_GITHUB_TOKEN }} diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index 52b8ee3e2..da3ec746a 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -18,7 +18,31 @@ concurrency: jobs: test-docker-compose: runs-on: ubuntu-latest + strategy: + matrix: + store: [ 'sqlite', 'postgres' ] + services: + postgres: + image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }} + env: + POSTGRES_USER: netbird + POSTGRES_PASSWORD: postgres + POSTGRES_DB: netbird + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + ports: + - 5432:5432 steps: + - name: Set Database Connection String + run: | + if [ "${{ matrix.store }}" == "postgres" ]; then + echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV + else + echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV + fi + - name: Install jq run: sudo apt-get install -y jq @@ -26,12 +50,12 @@ jobs: run: sudo apt-get install -y curl - name: Install Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: - go-version: "1.21.x" + go-version: "1.23.x" - name: Cache Go modules - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} @@ -39,7 +63,7 @@ jobs: ${{ runner.os }}-go- - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: cp setup.env run: cp infrastructure_files/tests/setup.env infrastructure_files/ @@ -58,7 +82,8 @@ jobs: CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" - CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite" + CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} + NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }} CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false - name: check values @@ -85,7 +110,8 @@ jobs: CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret CI_NETBIRD_SIGNAL_PORT: 12345 - CI_NETBIRD_STORE_CONFIG_ENGINE: "sqlite" + CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} + NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$' CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" @@ -123,6 +149,14 @@ jobs: grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP + grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" + # check relay values + grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml + grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml + grep '33445:33445' docker-compose.yml + grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$' + grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445" + grep -A 7 Relay management.json | egrep '"Secret": ".+"' - name: Install modules run: go mod tidy @@ -148,6 +182,15 @@ jobs: run: | docker build -t netbirdio/signal:latest . + - name: Build relay binary + working-directory: relay + run: CGO_ENABLED=0 go build -o netbird-relay main.go + + - name: Build relay docker image + working-directory: relay + run: | + docker build -t netbirdio/relay:latest . + - name: run docker compose up working-directory: infrastructure_files/artifacts run: | @@ -159,15 +202,15 @@ jobs: - name: test running containers run: | count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running) - test $count -eq 4 + test $count -eq 5 || docker compose logs working-directory: infrastructure_files/artifacts - name: test geolocation databases working-directory: infrastructure_files/artifacts run: | sleep 30 - docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb - docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db + docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb + docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db test-getting-started-script: runs-on: ubuntu-latest @@ -176,7 +219,7 @@ jobs: run: sudo apt-get install -y jq - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: run script with Zitadel PostgreSQL run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh @@ -191,7 +234,7 @@ jobs: run: test -f management.json - name: test turnserver.conf file gen postgres - run: | + run: | set -x test -f turnserver.conf grep external-ip turnserver.conf @@ -202,6 +245,9 @@ jobs: - name: test dashboard.env file gen postgres run: test -f dashboard.env + - name: test relay.env file gen postgres + run: test -f relay.env + - name: test zdb.env file gen postgres run: test -f zdb.env @@ -226,7 +272,7 @@ jobs: run: test -f management.json - name: test turnserver.conf file gen CockroachDB - run: | + run: | set -x test -f turnserver.conf grep external-ip turnserver.conf @@ -237,20 +283,5 @@ jobs: - name: test dashboard.env file gen CockroachDB run: test -f dashboard.env - test-download-geolite2-script: - runs-on: ubuntu-latest - steps: - - name: Install jq - run: sudo apt-get update && sudo apt-get install -y unzip sqlite3 - - - name: Checkout code - uses: actions/checkout@v3 - - - name: test script - run: bash -x infrastructure_files/download-geolite2.sh - - - name: test mmdb file exists - run: test -f GeoLite2-City.mmdb - - - name: test geonames file exists - run: test -f geonames.db + - name: test relay.env file gen CockroachDB + run: test -f relay.env diff --git a/.gitignore b/.gitignore index cdce46975..d0b4f82dd 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,3 @@ infrastructure_files/setup.env infrastructure_files/setup-*.env .vscode .DS_Store -GeoLite2-City* \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 7a219110a..cf2ce4f4f 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird builds: - id: netbird @@ -22,7 +24,7 @@ builds: goarch: 386 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -42,19 +44,19 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc - id: netbird-mgmt dir: management env: - - CGO_ENABLED=1 - - >- - {{- if eq .Runtime.Goos "linux" }} - {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} - {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} - {{- end }} + - CGO_ENABLED=1 + - >- + {{- if eq .Runtime.Goos "linux" }} + {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} + {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} + {{- end }} binary: netbird-mgmt goos: - linux @@ -64,7 +66,7 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-signal dir: signal @@ -78,7 +80,21 @@ builds: - arm ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" + + - id: netbird-relay + dir: relay + env: [CGO_ENABLED=0] + binary: netbird-relay + goos: + - linux + goarch: + - amd64 + - arm64 + - arm + ldflags: + - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser + mod_timestamp: "{{ .CommitTimestamp }}" archives: - builds: @@ -86,7 +102,6 @@ archives: - netbird-static nfpms: - - maintainer: Netbird description: Netbird client. homepage: https://netbird.io/ @@ -161,6 +176,52 @@ dockers: - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/relay:{{ .Version }}-amd64 + ids: + - netbird-relay + goarch: amd64 + use: buildx + dockerfile: relay/Dockerfile + build_flag_templates: + - "--platform=linux/amd64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/relay:{{ .Version }}-arm64v8 + ids: + - netbird-relay + goarch: arm64 + use: buildx + dockerfile: relay/Dockerfile + build_flag_templates: + - "--platform=linux/arm64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/relay:{{ .Version }}-arm + ids: + - netbird-relay + goarch: arm + goarm: 6 + use: buildx + dockerfile: relay/Dockerfile + build_flag_templates: + - "--platform=linux/arm" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=maintainer=dev@netbird.io" - image_templates: - netbirdio/signal:{{ .Version }}-amd64 ids: @@ -313,6 +374,18 @@ docker_manifests: - netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-amd64 + - name_template: netbirdio/relay:{{ .Version }} + image_templates: + - netbirdio/relay:{{ .Version }}-arm64v8 + - netbirdio/relay:{{ .Version }}-arm + - netbirdio/relay:{{ .Version }}-amd64 + + - name_template: netbirdio/relay:latest + image_templates: + - netbirdio/relay:{{ .Version }}-arm64v8 + - netbirdio/relay:{{ .Version }}-arm + - netbirdio/relay:{{ .Version }}-amd64 + - name_template: netbirdio/signal:{{ .Version }} image_templates: - netbirdio/signal:{{ .Version }}-arm64v8 @@ -344,10 +417,9 @@ docker_manifests: - netbirdio/management:{{ .Version }}-debug-amd64 brews: - - - ids: + - ids: - default - tap: + repository: owner: netbirdio name: homebrew-tap token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" @@ -364,7 +436,7 @@ brews: uploads: - name: debian ids: - - netbird-deb + - netbird-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com @@ -386,4 +458,4 @@ checksum: release: extra_files: - glob: ./infrastructure_files/getting-started-with-zitadel.sh - - glob: ./release_files/install.sh \ No newline at end of file + - glob: ./release_files/install.sh diff --git a/.goreleaser_ui.yaml b/.goreleaser_ui.yaml index fd92b5328..06577f4e3 100644 --- a/.goreleaser_ui.yaml +++ b/.goreleaser_ui.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui @@ -11,7 +13,7 @@ builds: - amd64 ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" - id: netbird-ui-windows dir: client/ui @@ -26,7 +28,7 @@ builds: ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -H windowsgui - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" archives: - id: linux-arch @@ -39,7 +41,6 @@ archives: - netbird-ui-windows nfpms: - - maintainer: Netbird description: Netbird client UI. homepage: https://netbird.io/ @@ -77,7 +78,7 @@ nfpms: uploads: - name: debian ids: - - netbird-ui-deb + - netbird-ui-deb mode: archive target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= username: dev@wiretrustee.com diff --git a/.goreleaser_ui_darwin.yaml b/.goreleaser_ui_darwin.yaml index 2c3afa91b..bccb7f471 100644 --- a/.goreleaser_ui_darwin.yaml +++ b/.goreleaser_ui_darwin.yaml @@ -1,3 +1,5 @@ +version: 2 + project_name: netbird-ui builds: - id: netbird-ui-darwin @@ -17,7 +19,7 @@ builds: - softfloat ldflags: - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - mod_timestamp: '{{ .CommitTimestamp }}' + mod_timestamp: "{{ .CommitTimestamp }}" tags: - load_wgnt_from_rsrc @@ -28,4 +30,4 @@ archives: checksum: name_template: "{{ .ProjectName }}_darwin_checksums.txt" changelog: - skip: true \ No newline at end of file + disable: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 492aa5c2e..c82cfc763 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR: **Goreleaser** ```shell -goreleaser --snapshot --rm-dist +goreleaser build --snapshot --clean ``` **golangci-lint** ```shell diff --git a/README.md b/README.md index 370445412..aa3ec41e5 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@
- +

@@ -30,7 +30,7 @@
See Documentation
- Join our Slack channel + Join our Slack channel
diff --git a/client/Dockerfile b/client/Dockerfile index a3220bf33..b9f7c1355 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.19 +FROM alpine:3.20 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] diff --git a/client/android/client.go b/client/android/client.go index d937e132e..229bcd974 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" @@ -15,7 +16,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util/net" ) @@ -26,7 +26,7 @@ type ConnectionListener interface { // TunAdapter export internal TunAdapter for mobile type TunAdapter interface { - iface.TunAdapter + device.TunAdapter } // IFaceDiscover export internal IFaceDiscover for mobile @@ -51,7 +51,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string - tunAdapter iface.TunAdapter + tunAdapter device.TunAdapter iFaceDiscover IFaceDiscover recorder *peer.Status ctxCancel context.CancelFunc diff --git a/client/cmd/down.go b/client/cmd/down.go index 4d9f1eba4..3a324cc19 100644 --- a/client/cmd/down.go +++ b/client/cmd/down.go @@ -42,6 +42,8 @@ var downCmd = &cobra.Command{ log.Errorf("call service down method: %v", err) return err } + + cmd.Println("Disconnected") return nil }, } diff --git a/client/cmd/login_test.go b/client/cmd/login_test.go index 6bb7eff4f..fa20435ea 100644 --- a/client/cmd/login_test.go +++ b/client/cmd/login_test.go @@ -5,8 +5,8 @@ import ( "strings" "testing" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) diff --git a/client/cmd/root_test.go b/client/cmd/root_test.go index abb7d41b2..4cbbe8783 100644 --- a/client/cmd/root_test.go +++ b/client/cmd/root_test.go @@ -4,6 +4,10 @@ import ( "fmt" "io" "testing" + + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/client/iface" ) func TestInitCommands(t *testing.T) { @@ -34,3 +38,44 @@ func TestInitCommands(t *testing.T) { }) } } + +func TestSetFlagsFromEnvVars(t *testing.T) { + var cmd = &cobra.Command{ + Use: "netbird", + Long: "test", + SilenceUsage: true, + Run: func(cmd *cobra.Command, args []string) { + SetFlagsFromEnvVars(cmd) + }, + } + + cmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, + `comma separated list of external IPs to map to the Wireguard interface`) + cmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") + cmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "Enable Rosenpass feature Rosenpass.") + cmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") + + t.Setenv("NB_EXTERNAL_IP_MAP", "abc,dec") + t.Setenv("NB_INTERFACE_NAME", "test-name") + t.Setenv("NB_ENABLE_ROSENPASS", "true") + t.Setenv("NB_WIREGUARD_PORT", "10000") + err := cmd.Execute() + if err != nil { + t.Fatalf("expected no error while running netbird command, got %v", err) + } + if len(natExternalIPs) != 2 { + t.Errorf("expected 2 external ips, got %d", len(natExternalIPs)) + } + if natExternalIPs[0] != "abc" || natExternalIPs[1] != "dec" { + t.Errorf("expected abc,dec, got %s,%s", natExternalIPs[0], natExternalIPs[1]) + } + if interfaceName != "test-name" { + t.Errorf("expected test-name, got %s", interfaceName) + } + if !rosenpassEnabled { + t.Errorf("expected rosenpassEnabled to be true, got false") + } + if wireguardPort != 10000 { + t.Errorf("expected wireguardPort to be 10000, got %d", wireguardPort) + } +} diff --git a/client/cmd/service.go b/client/cmd/service.go index 5c60744f9..855eb30fa 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -2,18 +2,21 @@ package cmd import ( "context" + "github.com/kardianos/service" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "google.golang.org/grpc" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/server" ) type program struct { - ctx context.Context - cancel context.CancelFunc - serv *grpc.Server + ctx context.Context + cancel context.CancelFunc + serv *grpc.Server + serverInstance *server.Server } func newProgram(ctx context.Context, cancel context.CancelFunc) *program { diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index d416afaac..86546e31c 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,6 +61,8 @@ func (p *program) Start(svc service.Service) error { } proto.RegisterDaemonServiceServer(p.serv, serverInstance) + p.serverInstance = serverInstance + log.Printf("started daemon server: %v", split[1]) if err := p.serv.Serve(listen); err != nil { log.Errorf("failed to serve daemon requests: %v", err) @@ -70,6 +72,14 @@ func (p *program) Start(svc service.Service) error { } func (p *program) Stop(srv service.Service) error { + if p.serverInstance != nil { + in := new(proto.DownRequest) + _, err := p.serverInstance.Down(p.ctx, in) + if err != nil { + log.Errorf("failed to stop daemon: %v", err) + } + } + p.cancel() if p.serv != nil { diff --git a/client/cmd/status.go b/client/cmd/status.go index d9b7a9c91..ed3daa2b5 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -31,9 +31,9 @@ type peerStateDetailOutput struct { Status string `json:"status" yaml:"status"` LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"` ConnType string `json:"connectionType" yaml:"connectionType"` - Direct bool `json:"direct" yaml:"direct"` IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"` IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"` + RelayAddress string `json:"relayAddress" yaml:"relayAddress"` LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"` TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"` TransferSent int64 `json:"transferSent" yaml:"transferSent"` @@ -335,16 +335,18 @@ func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput { func mapPeers(peers []*proto.PeerState) peersStateOutput { var peersStateDetail []peerStateDetailOutput - localICE := "" - remoteICE := "" - localICEEndpoint := "" - remoteICEEndpoint := "" - connType := "" peersConnected := 0 - lastHandshake := time.Time{} - transferReceived := int64(0) - transferSent := int64(0) for _, pbPeerState := range peers { + localICE := "" + remoteICE := "" + localICEEndpoint := "" + remoteICEEndpoint := "" + relayServerAddress := "" + connType := "" + lastHandshake := time.Time{} + transferReceived := int64(0) + transferSent := int64(0) + isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String() if skipDetailByFilters(pbPeerState, isPeerConnected) { continue @@ -360,6 +362,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { if pbPeerState.Relayed { connType = "Relayed" } + relayServerAddress = pbPeerState.GetRelayAddress() lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local() transferReceived = pbPeerState.GetBytesRx() transferSent = pbPeerState.GetBytesTx() @@ -372,7 +375,6 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { Status: pbPeerState.GetConnStatus(), LastStatusUpdate: timeLocal, ConnType: connType, - Direct: pbPeerState.GetDirect(), IceCandidateType: iceCandidateType{ Local: localICE, Remote: remoteICE, @@ -381,6 +383,7 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { Local: localICEEndpoint, Remote: remoteICEEndpoint, }, + RelayAddress: relayServerAddress, FQDN: pbPeerState.GetFqdn(), LastWireguardHandshake: lastHandshake, TransferReceived: transferReceived, @@ -641,9 +644,9 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Status: %s\n"+ " -- detail --\n"+ " Connection type: %s\n"+ - " Direct: %t\n"+ " ICE candidate (Local/Remote): %s/%s\n"+ " ICE candidate endpoints (Local/Remote): %s/%s\n"+ + " Relay server address: %s\n"+ " Last connection update: %s\n"+ " Last WireGuard handshake: %s\n"+ " Transfer status (received/sent) %s/%s\n"+ @@ -655,11 +658,11 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo peerState.PubKey, peerState.Status, peerState.ConnType, - peerState.Direct, localICE, remoteICE, localICEEndpoint, remoteICEEndpoint, + peerState.RelayAddress, timeAgo(peerState.LastStatusUpdate), timeAgo(peerState.LastWireguardHandshake), toIEC(peerState.TransferReceived), @@ -802,6 +805,9 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil { peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port) } + + peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + for i, route := range peer.Routes { peer.Routes[i] = a.AnonymizeIPString(route) } diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index 46620a956..ca43df8a5 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -37,7 +37,6 @@ var resp = &proto.StatusResponse{ ConnStatus: "Connected", ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)), Relayed: false, - Direct: true, LocalIceCandidateType: "", RemoteIceCandidateType: "", LocalIceCandidateEndpoint: "", @@ -57,7 +56,6 @@ var resp = &proto.StatusResponse{ ConnStatus: "Connected", ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)), Relayed: true, - Direct: false, LocalIceCandidateType: "relay", RemoteIceCandidateType: "prflx", LocalIceCandidateEndpoint: "10.0.0.1:10001", @@ -137,7 +135,6 @@ var overview = statusOutputOverview{ Status: "Connected", LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC), ConnType: "P2P", - Direct: true, IceCandidateType: iceCandidateType{ Local: "", Remote: "", @@ -161,7 +158,6 @@ var overview = statusOutputOverview{ Status: "Connected", LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC), ConnType: "Relayed", - Direct: false, IceCandidateType: iceCandidateType{ Local: "relay", Remote: "prflx", @@ -283,7 +279,6 @@ func TestParsingToJSON(t *testing.T) { "status": "Connected", "lastStatusUpdate": "2001-01-01T01:01:01Z", "connectionType": "P2P", - "direct": true, "iceCandidateType": { "local": "", "remote": "" @@ -292,6 +287,7 @@ func TestParsingToJSON(t *testing.T) { "local": "", "remote": "" }, + "relayAddress": "", "lastWireguardHandshake": "2001-01-01T01:01:02Z", "transferReceived": 200, "transferSent": 100, @@ -308,7 +304,6 @@ func TestParsingToJSON(t *testing.T) { "status": "Connected", "lastStatusUpdate": "2002-02-02T02:02:02Z", "connectionType": "Relayed", - "direct": false, "iceCandidateType": { "local": "relay", "remote": "prflx" @@ -317,6 +312,7 @@ func TestParsingToJSON(t *testing.T) { "local": "10.0.0.1:10001", "remote": "10.0.10.1:10002" }, + "relayAddress": "", "lastWireguardHandshake": "2002-02-02T02:02:03Z", "transferReceived": 2000, "transferSent": 1000, @@ -408,13 +404,13 @@ func TestParsingToYAML(t *testing.T) { status: Connected lastStatusUpdate: 2001-01-01T01:01:01Z connectionType: P2P - direct: true iceCandidateType: local: "" remote: "" iceCandidateEndpoint: local: "" remote: "" + relayAddress: "" lastWireguardHandshake: 2001-01-01T01:01:02Z transferReceived: 200 transferSent: 100 @@ -428,13 +424,13 @@ func TestParsingToYAML(t *testing.T) { status: Connected lastStatusUpdate: 2002-02-02T02:02:02Z connectionType: Relayed - direct: false iceCandidateType: local: relay remote: prflx iceCandidateEndpoint: local: 10.0.0.1:10001 remote: 10.0.10.1:10002 + relayAddress: "" lastWireguardHandshake: 2002-02-02T02:02:03Z transferReceived: 2000 transferSent: 1000 @@ -505,9 +501,9 @@ func TestParsingToDetail(t *testing.T) { Status: Connected -- detail -- Connection type: P2P - Direct: true ICE candidate (Local/Remote): -/- ICE candidate endpoints (Local/Remote): -/- + Relay server address: Last connection update: %s Last WireGuard handshake: %s Transfer status (received/sent) 200 B/100 B @@ -521,9 +517,9 @@ func TestParsingToDetail(t *testing.T) { Status: Connected -- detail -- Connection type: Relayed - Direct: false ICE candidate (Local/Remote): relay/prflx ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002 + Relay server address: Last connection update: %s Last WireGuard handshake: %s Transfer status (received/sent) 2.0 KiB/1000 B diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 984aa6df7..f0dc8bf21 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -57,7 +57,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - srv, err := sig.NewServer(otel.Meter("")) + srv, err := sig.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) sigProto.RegisterSignalExchangeServer(s, srv) @@ -98,8 +98,9 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste if err != nil { t.Fatal(err) } - turnManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) + + secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/cmd/up.go b/client/cmd/up.go index 2ed6e41d2..05ecce9e0 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -15,11 +15,11 @@ import ( gstatus "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/durationpb" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/util" ) @@ -168,7 +168,10 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { ctx, cancel = context.WithCancel(ctx) SetupCloseHandler(ctx, cancel) - connectClient := internal.NewConnectClient(ctx, config, peer.NewRecorder(config.ManagementURL.String())) + r := peer.NewRecorder(config.ManagementURL.String()) + r.GetFullStatus() + + connectClient := internal.NewConnectClient(ctx, config, r) return connectClient.Run() } diff --git a/client/errors/errors.go b/client/errors/errors.go index cef999ac8..8faadbda5 100644 --- a/client/errors/errors.go +++ b/client/errors/errors.go @@ -8,8 +8,8 @@ import ( ) func formatError(es []error) string { - if len(es) == 0 { - return fmt.Sprintf("0 error occurred:\n\t* %s", es[0]) + if len(es) == 1 { + return fmt.Sprintf("1 error occurred:\n\t* %s", es[0]) } points := make([]string, len(es)) diff --git a/client/firewall/iface.go b/client/firewall/iface.go index 882daef75..f349f9210 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,11 +1,13 @@ package firewall -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/client/iface/device" +) // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { Name() string - Address() iface.WGAddress + Address() device.WGAddress IsUserspaceBind() bool - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error } diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index b77cc8f43..c6a96a876 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -19,24 +19,22 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT" - - postRoutingMark = "0x000007e4" ) type aclManager struct { - iptablesClient *iptables.IPTables - wgIface iFaceMapper - routeingFwChainName string + iptablesClient *iptables.IPTables + wgIface iFaceMapper + routingFwChainName string entries map[string][][]string ipsetStore *ipsetStore } -func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) { +func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { m := &aclManager{ - iptablesClient: iptablesClient, - wgIface: wgIface, - routeingFwChainName: routeingFwChainName, + iptablesClient: iptablesClient, + wgIface: wgIface, + routingFwChainName: routingFwChainName, entries: make(map[string][][]string), ipsetStore: newIpsetStore(), @@ -61,7 +59,7 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, route return m, nil } -func (m *aclManager) AddFiltering( +func (m *aclManager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -127,7 +125,7 @@ func (m *aclManager) AddFiltering( return nil, fmt.Errorf("rule already exists") } - if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil { + if err := m.iptablesClient.Append("filter", chain, specs...); err != nil { return nil, err } @@ -139,28 +137,16 @@ func (m *aclManager) AddFiltering( chain: chain, } - if !shouldAddToPrerouting(protocol, dPort, direction) { - return []firewall.Rule{rule}, nil - } - - rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip) - if err != nil { - return []firewall.Rule{rule}, err - } - return []firewall.Rule{rule, rulePrerouting}, nil + return []firewall.Rule{rule}, nil } -// DeleteRule from the firewall by rule definition -func (m *aclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *aclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") } - if r.chain == "PREROUTING" { - goto DELETERULE - } - if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { // delete IP from ruleset IPs list and ipset if _, ok := ipsetList.ips[r.ip]; ok { @@ -185,14 +171,7 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error { } } -DELETERULE: - var table string - if r.chain == "PREROUTING" { - table = "mangle" - } else { - table = "filter" - } - err := m.iptablesClient.Delete(table, r.chain, r.specs...) + err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) if err != nil { log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err) } @@ -203,44 +182,6 @@ func (m *aclManager) Reset() error { return m.cleanChains() } -func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) { - var src []string - if ipsetName != "" { - src = []string{"-m", "set", "--set", ipsetName, "src"} - } else { - src = []string{"-s", ip.String()} - } - specs := []string{ - "-d", m.wgIface.Address().IP.String(), - "-p", protocol, - "--dport", port, - "-j", "MARK", "--set-mark", postRoutingMark, - } - - specs = append(src, specs...) - - ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...) - if err != nil { - return nil, fmt.Errorf("failed to check rule: %w", err) - } - if ok { - return nil, fmt.Errorf("rule already exists") - } - - if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil { - return nil, err - } - - rule := &Rule{ - ruleID: uuid.New().String(), - specs: specs, - ipsetName: ipsetName, - ip: ip.String(), - chain: "PREROUTING", - } - return rule, nil -} - // todo write less destructive cleanup mechanism func (m *aclManager) cleanChains() error { ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules) @@ -291,25 +232,6 @@ func (m *aclManager) cleanChains() error { } } - ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to list chains: %s", err) - return err - } - if ok { - for _, rule := range m.entries["PREROUTING"] { - err := m.iptablesClient.DeleteIfExists("mangle", "PREROUTING", rule...) - if err != nil { - log.Errorf("failed to delete rule: %v, %s", rule, err) - } - } - err = m.iptablesClient.ClearChain("mangle", "PREROUTING") - if err != nil { - log.Debugf("failed to clear %s chain: %s", "PREROUTING", err) - return err - } - } - for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := ipset.Flush(ipsetName); err != nil { log.Errorf("flush ipset %q during reset: %v", ipsetName, err) @@ -338,17 +260,9 @@ func (m *aclManager) createDefaultChains() error { for chainName, rules := range m.entries { for _, rule := range rules { - if chainName == "FORWARD" { - // position 2 because we add it after router's, jump rule - if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } - } else { - if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil { - log.Debugf("failed to create input chain jump rule: %s", err) - return err - } + if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { + log.Debugf("failed to create input chain jump rule: %s", err) + return err } } } @@ -356,40 +270,29 @@ func (m *aclManager) createDefaultChains() error { return nil } +// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed. +// We want to make sure our traffic is not dropped by existing rules. + +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. + +// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. func (m *aclManager) seedInitialEntries() { - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("INPUT", - []string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules}) + established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - - m.appendToEntries("OUTPUT", - []string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules}) + m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) + m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) + m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) + m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) - m.appendToEntries("FORWARD", - []string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", - []string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) - m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName}) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName}) - - m.appendToEntries("PREROUTING", - []string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark}) + m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) + m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) } func (m *aclManager) appendToEntries(chainName string, spec []string) { @@ -456,18 +359,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string { return ipsetName } } - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil { - return false - } - return true -} diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2d231ec45..6fefd58e6 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -4,13 +4,14 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/coreos/go-iptables/iptables" log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) // Manager of iptables firewall @@ -21,7 +22,7 @@ type Manager struct { ipv4Client *iptables.IPTables aclMgr *aclManager - router *routerManager + router *router } // iFaceMapper defines subset methods of interface required for manager @@ -43,12 +44,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouterManager(context, iptablesClient) + m.router, err = newRouter(context, iptablesClient, wgIface) if err != nil { log.Debugf("failed to initialize route related chains: %s", err) return nil, err } - m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName()) + m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) if err != nil { log.Debugf("failed to initialize ACL manager: %s", err) return nil, err @@ -57,10 +58,10 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering adds a rule to the firewall // // Comment will be ignored because some system this feature is not supported -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, protocol firewall.Protocol, sPort *firewall.Port, @@ -73,33 +74,62 @@ func (m *Manager) AddFiltering( m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) + return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering( + sources [] netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclMgr.DeleteRule(rule) + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.aclMgr.DeletePeerRule(rule) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.InsertRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) +} + +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + return firewall.SetLegacyManagement(m.router, isLegacy) } // Reset firewall to the default state @@ -125,7 +155,7 @@ func (m *Manager) AllowNetbird() error { return nil } - _, err := m.AddFiltering( + _, err := m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -138,7 +168,7 @@ func (m *Manager) AllowNetbird() error { if err != nil { return fmt.Errorf("failed to allow netbird interface traffic: %w", err) } - _, err = m.AddFiltering( + _, err = m.AddPeerFiltering( net.ParseIP("0.0.0.0"), "all", nil, @@ -153,3 +183,7 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } + +func getConntrackEstablished() []string { + return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} +} diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index ceb116c62..498d8f58b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,9 +11,24 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("10.20.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("10.20.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +55,8 @@ func TestIptablesManager(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("10.20.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("10.20.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } - // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second) @@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) { t.Run("add first rule", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") require.NoError(t, err, "failed to add rule") for _, r := range rule1 { @@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) { port := &fw.Port{ Values: []int{8043: 8046}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") require.NoError(t, err, "failed to add rule") @@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) @@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -119,7 +119,7 @@ func TestIptablesManager(t *testing.T) { // add second rule ip := net.ParseIP("10.20.0.3") port := &fw.Port{Values: []int{5353}} - _, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") + _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") require.NoError(t, err, "failed to add rule") err = manager.Reset() @@ -170,7 +170,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) { ip := net.ParseIP("10.20.0.2") port := &fw.Port{Values: []int{8080}} - rule1, err = manager.AddFiltering( + rule1, err = manager.AddPeerFiltering( ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "default", "accept HTTP traffic", ) @@ -189,7 +189,7 @@ func TestIptablesManagerIPSet(t *testing.T) { port := &fw.Port{ Values: []int{443}, } - rule2, err = manager.AddFiltering( + rule2, err = manager.AddPeerFiltering( ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "default", "accept HTTPS traffic from ports range", ) @@ -202,7 +202,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete first rule", func(t *testing.T) { for _, r := range rule1 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") @@ -211,7 +211,7 @@ func TestIptablesManagerIPSet(t *testing.T) { t.Run("delete second rule", func(t *testing.T) { for _, r := range rule2 { - err := manager.DeleteRule(r) + err := manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") @@ -269,9 +269,9 @@ func TestIptablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index e8f09a106..737b20785 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -5,368 +5,478 @@ package iptables import ( "context" "fmt" + "net/netip" + "strconv" "strings" "github.com/coreos/go-iptables/iptables" + "github.com/hashicorp/go-multierror" + "github.com/nadoo/ipset" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) const ( - Ipv4Forwarding = "netbird-rt-forwarding" - ipv4Nat = "netbird-rt-nat" + ipv4Nat = "netbird-rt-nat" ) // constants needed to manage and create iptable rules const ( tableFilter = "filter" tableNat = "nat" - chainFORWARD = "FORWARD" chainPOSTROUTING = "POSTROUTING" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWD = "NETBIRD-RT-FWD" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" + + matchSet = "--match-set" ) -type routerManager struct { - ctx context.Context - stop context.CancelFunc - iptablesClient *iptables.IPTables - rules map[string][]string +type routeFilteringRuleParams struct { + Sources []netip.Prefix + Destination netip.Prefix + Proto firewall.Protocol + SPort *firewall.Port + DPort *firewall.Port + Direction firewall.RuleDirection + Action firewall.Action + SetName string } -func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) { +type router struct { + ctx context.Context + stop context.CancelFunc + iptablesClient *iptables.IPTables + rules map[string][]string + ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { ctx, cancel := context.WithCancel(parentCtx) - m := &routerManager{ + r := &router{ ctx: ctx, stop: cancel, iptablesClient: iptablesClient, rules: make(map[string][]string), + wgIface: wgIface, } - err := m.cleanUpDefaultForwardRules() + r.ipsetCounter = refcounter.New( + r.createIpSet, + func(name string, _ struct{}) error { + return r.deleteIpSet(name) + }, + ) + + if err := ipset.Init(); err != nil { + return nil, fmt.Errorf("init ipset: %w", err) + } + + err := r.cleanUpDefaultForwardRules() if err != nil { - log.Errorf("failed to cleanup routing rules: %s", err) + log.Errorf("cleanup routing rules: %s", err) return nil, err } - err = m.createContainers() + err = r.createContainers() if err != nil { - log.Errorf("failed to create containers for route: %s", err) + log.Errorf("create containers for route: %s", err) } - return m, err + return r, err } -// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain -func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { - err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair) - if err != nil { - return err +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil } - err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair)) - if err != nil { - return err + var setName string + if len(sources) > 1 { + setName = firewall.GenerateSetName(sources) + if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { + return nil, fmt.Errorf("create or get ipset: %w", err) + } + } + + params := routeFilteringRuleParams{ + Sources: sources, + Destination: destination, + Proto: proto, + SPort: sPort, + DPort: dPort, + Action: action, + SetName: setName, + } + + rule := genRouteFilteringRuleSpec(params) + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return nil, fmt.Errorf("add route rule: %v", err) + } + + r.rules[string(ruleKey)] = rule + + return ruleKey, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + ruleKey := rule.GetRuleID() + + if rule, exists := r.rules[ruleKey]; exists { + setName := r.findSetNameInRule(rule) + + if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("delete route rule: %v", err) + } + delete(r.rules, ruleKey) + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("failed to remove ipset: %w", err) + } + } + } else { + log.Debugf("route rule %s not found", ruleKey) + } + + return nil +} + +func (r *router) findSetNameInRule(rule []string) string { + for i, arg := range rule { + if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { + return rule[i+3] + } + } + return "" +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { + if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { + return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) + } + + for _, prefix := range sources { + if err := ipset.AddPrefix(setName, prefix); err != nil { + return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) + } + } + + return struct{}{}, nil +} + +func (r *router) deleteIpSet(setName string) error { + if err := ipset.Destroy(setName); err != nil { + return fmt.Errorf("destroy set %s: %w", setName, err) + } + return nil +} + +// AddNatRule inserts an iptables rule pair into the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } } if !pair.Masquerade { return nil } - err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) - if err != nil { - return err + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) } - err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) - if err != nil { - return err + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) } return nil } -// insertRoutingRule inserts an iptables rule -func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - var err error +// RemoveNatRule removes an iptables rule pair from forwarding and nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if err := r.removeLegacyRouteRule(pair); err != nil { + return err + } + + rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} + if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + r.rules[ruleKey] = rule + + return nil +} + +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) } - delete(i.rules, ruleKey) - } - - err = i.iptablesClient.Insert(table, chain, 1, rule...) - if err != nil { - return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) - } - - i.rules[ruleKey] = rule - - return nil -} - -// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains -func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error { - err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair)) - if err != nil { - return err - } - - if !pair.Masquerade { - return nil - } - - err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair) - if err != nil { - return err - } - - err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair)) - if err != nil { - return err + delete(r.rules, ruleKey) + } else { + log.Debugf("legacy forwarding rule %s not found", ruleKey) } return nil } -func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error { - var err error +// GetLegacyManagement returns the current legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} - ruleKey := firewall.GenKey(keyFormat, pair.ID) - existingRule, found := i.rules[ruleKey] - if found { - err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { - return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) } } - delete(i.rules, ruleKey) - - return nil + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) RouteingFwChainName() string { - return chainRTFWD -} - -func (i *routerManager) Reset() error { - err := i.cleanUpDefaultForwardRules() - if err != nil { - return err +func (r *router) Reset() error { + var merr *multierror.Error + if err := r.cleanUpDefaultForwardRules(); err != nil { + merr = multierror.Append(merr, err) } - i.rules = make(map[string][]string) - return nil + r.rules = make(map[string][]string) + + if err := r.ipsetCounter.Flush(); err != nil { + merr = multierror.Append(merr, err) + } + + return nberrors.FormatErrorOrNil(merr) } -func (i *routerManager) cleanUpDefaultForwardRules() error { - err := i.cleanJumpRules() +func (r *router) cleanUpDefaultForwardRules() error { + err := r.cleanJumpRules() if err != nil { return err } log.Debug("flushing routing related tables") - ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTFWD, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD) + for _, chain := range []string{chainRTFWD, chainRTNAT} { + table := tableFilter + if chain == chainRTNAT { + table = tableNat + } + + ok, err := r.iptablesClient.ChainExists(table, chain) if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err) + log.Errorf("failed check chain %s, error: %v", chain, err) return err + } else if ok { + err = r.iptablesClient.ClearAndDeleteChain(table, chain) + if err != nil { + log.Errorf("failed cleaning chain %s, error: %v", chain, err) + return err + } } } - ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed check chain %s,error: %v", chainRTNAT, err) - return err - } else if ok { - err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT) - if err != nil { - log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err) - return err - } - } - return nil -} - -func (i *routerManager) createContainers() error { - if i.rules[Ipv4Forwarding] != nil { - return nil - } - - errMSGFormat := "failed creating chain %s,error: %v" - err := i.createChain(tableFilter, chainRTFWD) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTFWD, err) - } - - err = i.createChain(tableNat, chainRTNAT) - if err != nil { - return fmt.Errorf(errMSGFormat, chainRTNAT, err) - } - - err = i.addJumpRules() - if err != nil { - return fmt.Errorf("error while creating jump rules: %v", err) - } - return nil } -// addJumpRules create jump rules to send packets to NetBird chains -func (i *routerManager) addJumpRules() error { - rule := []string{"-j", chainRTFWD} - err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...) +func (r *router) createContainers() error { + for _, chain := range []string{chainRTFWD, chainRTNAT} { + if err := r.createAndSetupChain(chain); err != nil { + return fmt.Errorf("create chain %s: %v", chain, err) + } + } + + if err := r.insertEstablishedRule(chainRTFWD); err != nil { + return fmt.Errorf("insert established rule: %v", err) + } + + return r.addJumpRules() +} + +func (r *router) createAndSetupChain(chain string) error { + table := r.getTableForChain(chain) + + if err := r.iptablesClient.NewChain(table, chain); err != nil { + return fmt.Errorf("failed creating chain %s, error: %v", chain, err) + } + + return nil +} + +func (r *router) getTableForChain(chain string) string { + if chain == chainRTNAT { + return tableNat + } + return tableFilter +} + +func (r *router) insertEstablishedRule(chain string) error { + establishedRule := getConntrackEstablished() + + err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...) + if err != nil { + return fmt.Errorf("failed to insert established rule: %v", err) + } + + ruleKey := "established-" + chain + r.rules[ruleKey] = establishedRule + + return nil +} + +func (r *router) addJumpRules() error { + rule := []string{"-j", chainRTNAT} + err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) if err != nil { return err } - i.rules[Ipv4Forwarding] = rule - - rule = []string{"-j", chainRTNAT} - err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) - if err != nil { - return err - } - i.rules[ipv4Nat] = rule + r.rules[ipv4Nat] = rule return nil } -// cleanJumpRules cleans jump rules that was sending packets to NetBird chains -func (i *routerManager) cleanJumpRules() error { - var err error - errMSGFormat := "failed cleaning rule from chain %s,err: %v" - rule, found := i.rules[Ipv4Forwarding] +func (r *router) cleanJumpRules() error { + rule, found := r.rules[ipv4Nat] if found { - err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...) + err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) if err != nil { - return fmt.Errorf(errMSGFormat, chainFORWARD, err) - } - } - rule, found = i.rules[ipv4Nat] - if found { - err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) - if err != nil { - return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err) + return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) } } - rules, err := i.iptablesClient.List("nat", "POSTROUTING") - if err != nil { - return fmt.Errorf("failed to list rules: %s", err) - } - - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...) - if err != nil { - return fmt.Errorf("failed to delete postrouting jump rule: %s", err) - } - } - - rules, err = i.iptablesClient.List(tableFilter, "FORWARD") - if err != nil { - return fmt.Errorf("failed to list rules in FORWARD chain: %s", err) - } - - for _, ruleString := range rules { - if !strings.Contains(ruleString, "NETBIRD") { - continue - } - rule := strings.Fields(ruleString) - err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...) - if err != nil { - return fmt.Errorf("failed to delete FORWARD jump rule: %s", err) - } - } return nil } -func (i *routerManager) createChain(table, newChain string) error { - chains, err := i.iptablesClient.ListChains(table) - if err != nil { - return fmt.Errorf("couldn't get %s table chains, error: %v", table, err) - } +func (r *router) addNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) - shouldCreateChain := true - for _, chain := range chains { - if chain == newChain { - shouldCreateChain = false - } - } - - if shouldCreateChain { - err = i.iptablesClient.NewChain(table, newChain) - if err != nil { - return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err) - } - - // Add the loopback return rule to the NAT chain - loopbackRule := []string{"-o", "lo", "-j", "RETURN"} - err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...) - if err != nil { - return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err) - } - - err = i.iptablesClient.Append(table, newChain, "-j", "RETURN") - if err != nil { - return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err) - } - - } - return nil -} - -// addNATRule appends an iptables rule pair to the nat chain -func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { - ruleKey := firewall.GenKey(keyFormat, pair.ID) - rule := genRuleSpec(jump, pair.Source, pair.Destination) - existingRule, found := i.rules[ruleKey] - if found { - err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...) - if err != nil { + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) } - delete(i.rules, ruleKey) + delete(r.rules, ruleKey) } - // inserting after loopback ignore rule - err := i.iptablesClient.Insert(table, chain, 2, rule...) - if err != nil { + rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) + if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) } - i.rules[ruleKey] = rule + r.rules[ruleKey] = rule return nil } -// genRuleSpec generates rule specification -func genRuleSpec(jump, source, destination string) []string { - return []string{"-s", source, "-d", destination, "-j", jump} +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { + return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) + } + + delete(r.rules, ruleKey) + } else { + log.Debugf("nat rule %s not found", ruleKey) + } + + return nil } -func getIptablesRuleType(table string) string { - ruleType := "forwarding" - if table == tableNat { - ruleType = "nat" +func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { + intdir := "-i" + if inverse { + intdir = "-o" } - return ruleType + return []string{intdir, intf, "-s", source.String(), "-d", destination.String(), "-j", jump} +} + +func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { + var rule []string + + if params.SetName != "" { + rule = append(rule, "-m", "set", matchSet, params.SetName, "src") + } else if len(params.Sources) > 0 { + source := params.Sources[0] + rule = append(rule, "-s", source.String()) + } + + rule = append(rule, "-d", params.Destination.String()) + + if params.Proto != firewall.ProtocolALL { + rule = append(rule, "-p", strings.ToLower(string(params.Proto))) + rule = append(rule, applyPort("--sport", params.SPort)...) + rule = append(rule, applyPort("--dport", params.DPort)...) + } + + rule = append(rule, "-j", actionToStr(params.Action)) + + return rule +} + +func applyPort(flag string, port *firewall.Port) []string { + if port == nil { + return nil + } + + if port.IsRange && len(port.Values) == 2 { + return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])} + } + + if len(port.Values) > 1 { + portList := make([]string, len(port.Values)) + for i, p := range port.Values { + portList[i] = strconv.Itoa(p) + } + return []string{"-m", "multiport", flag, strings.Join(portList, ",")} + } + + return []string{flag, strconv.Itoa(port.Values[0])} } diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 79b970c36..6cede09e2 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -4,11 +4,13 @@ package iptables import ( "context" + "net/netip" "os/exec" "testing" "github.com/coreos/go-iptables/iptables" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -28,7 +30,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "should return a valid iptables manager") defer func() { @@ -37,26 +39,22 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.Len(t, manager.rules, 2, "should have created rules map") - exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD) - require.True(t, exists, "forwarding rule should exist") - - exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) + exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.True(t, exists, "postrouting rule should exist") pair := firewall.RouterPair{ ID: "abc", - Source: "100.100.100.1/32", - Destination: "100.100.100.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.100.0/24"), Masquerade: true, } - forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination) + forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) require.NoError(t, err, "inserting rule should not return error") - nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination) + nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false) err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...) require.NoError(t, err, "inserting rule should not return error") @@ -65,7 +63,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { require.NoError(t, err, "shouldn't return error") } -func TestIptablesManager_InsertRoutingRules(t *testing.T) { +func TestIptablesManager_AddNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -76,7 +74,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { @@ -86,35 +84,13 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } }() - err = manager.InsertRoutingRules(testCase.InputPair) + err = manager.AddNatRule(testCase.InputPair) require.NoError(t, err, "forwarding pair should be inserted") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "forwarding rule should exist") - - foundRule, found := manager.rules[forwardRuleKey] - require.True(t, found, "forwarding rule should exist in the manager map") - require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.True(t, exists, "income forwarding rule should exist") - - foundRule, found = manager.rules[inForwardRuleKey] - require.True(t, found, "income forwarding rule should exist in the manager map") - require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) if testCase.InputPair.Masquerade { require.True(t, exists, "nat rule should be created") @@ -127,8 +103,8 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { require.False(t, foundNat, "nat rule should not exist in the map") } - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) @@ -146,7 +122,7 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) { } } -func TestIptablesManager_RemoveRoutingRules(t *testing.T) { +func TestIptablesManager_RemoveNatRule(t *testing.T) { if !isIptablesSupported() { t.SkipNow() @@ -156,7 +132,7 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouterManager(context.TODO(), iptablesClient) + manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) require.NoError(t, err, "shouldn't return error") defer func() { _ = manager.Reset() @@ -164,26 +140,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "shouldn't return error") - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) - - err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) - require.NoError(t, err, "inserting rule should not return error") - - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) - natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) + natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...) require.NoError(t, err, "inserting rule should not return error") - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) - inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) + inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...) require.NoError(t, err, "inserting rule should not return error") @@ -191,28 +155,14 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { err = manager.Reset() require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") - exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "forwarding rule should not exist") - - _, found := manager.rules[forwardRuleKey] - require.False(t, found, "forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...) - require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) - require.False(t, exists, "income forwarding rule should not exist") - - _, found = manager.rules[inForwardRuleKey] - require.False(t, found, "income forwarding rule should exist in the manager map") - - exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...) + exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) require.False(t, exists, "nat rule should not exist") - _, found = manager.rules[natRuleKey] + _, found := manager.rules[natRuleKey] require.False(t, found, "nat rule should exist in the manager map") exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) @@ -221,7 +171,175 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) { _, found = manager.rules[inNatRuleKey] require.False(t, found, "income nat rule should exist in the manager map") - + }) + } +} + +func TestRouter_AddRouteFiltering(t *testing.T) { + if !isIptablesSupported() { + t.Skip("iptables not supported on this system") + } + + iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + require.NoError(t, err, "Failed to create iptables client") + + r, err := newRouter(context.Background(), iptablesClient, ifaceMock) + require.NoError(t, err, "Failed to create router manager") + + defer func() { + err := r.Reset() + require.NoError(t, err, "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + // Log the internal rule + t.Logf("Internal rule: %v", rule) + + // Check if the rule exists in iptables + exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) + assert.NoError(t, err, "Failed to check rule existence") + assert.True(t, exists, "Rule not found in iptables") + + // Verify rule content + params := routeFilteringRuleParams{ + Sources: tt.sources, + Destination: tt.destination, + Proto: tt.proto, + SPort: tt.sPort, + DPort: tt.dPort, + Action: tt.action, + SetName: "", + } + + expectedRule := genRouteFilteringRuleSpec(params) + + if tt.expectSet { + setName := firewall.GenerateSetName(tt.sources) + params.SetName = setName + expectedRule = genRouteFilteringRuleSpec(params) + + // Check if the set was created + _, exists := r.ipsetCounter.Get(setName) + assert.True(t, exists, "IPSet not created") + } + + assert.Equal(t, expectedRule, rule, "Rule content mismatch") + + // Clean up + err = r.DeleteRouteRule(ruleKey) + require.NoError(t, err, "Failed to delete rule") }) } } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 6e4edb63e..a6185d370 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -1,15 +1,21 @@ package manager import ( + "crypto/sha256" + "encoding/hex" "fmt" "net" + "net/netip" + "sort" + "strings" + + log "github.com/sirupsen/logrus" ) const ( - NatFormat = "netbird-nat-%s" - ForwardingFormat = "netbird-fwd-%s" - InNatFormat = "netbird-nat-in-%s" - InForwardingFormat = "netbird-fwd-in-%s" + ForwardingFormatPrefix = "netbird-fwd-" + ForwardingFormat = "netbird-fwd-%s-%t" + NatFormat = "netbird-nat-%s-%t" ) // Rule abstraction should be implemented by each firewall manager @@ -49,11 +55,11 @@ type Manager interface { // AllowNetbird allows netbird interface traffic AllowNetbird() error - // AddFiltering rule to the firewall + // AddPeerFiltering adds a rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule - AddFiltering( + AddPeerFiltering( ip net.IP, proto Protocol, sPort *Port, @@ -64,17 +70,25 @@ type Manager interface { comment string, ) ([]Rule, error) - // DeleteRule from the firewall by rule definition - DeleteRule(rule Rule) error + // DeletePeerRule from the firewall by rule definition + DeletePeerRule(rule Rule) error // IsServerRouteSupported returns true if the firewall supports server side routing operations IsServerRouteSupported() bool - // InsertRoutingRules inserts a routing firewall rule - InsertRoutingRules(pair RouterPair) error + AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) - // RemoveRoutingRules removes a routing firewall rule - RemoveRoutingRules(pair RouterPair) error + // DeleteRouteRule deletes a routing rule + DeleteRouteRule(rule Rule) error + + // AddNatRule inserts a routing NAT rule + AddNatRule(pair RouterPair) error + + // RemoveNatRule removes a routing NAT rule + RemoveNatRule(pair RouterPair) error + + // SetLegacyManagement sets the legacy management mode + SetLegacyManagement(legacy bool) error // Reset firewall to the default state Reset() error @@ -83,6 +97,89 @@ type Manager interface { Flush() error } -func GenKey(format string, input string) string { - return fmt.Sprintf(format, input) +func GenKey(format string, pair RouterPair) string { + return fmt.Sprintf(format, pair.ID, pair.Inverse) +} + +// LegacyManager defines the interface for legacy management operations +type LegacyManager interface { + RemoveAllLegacyRouteRules() error + GetLegacyManagement() bool + SetLegacyManagement(bool) +} + +// SetLegacyManagement sets the route manager to use legacy management +func SetLegacyManagement(router LegacyManager, isLegacy bool) error { + oldLegacy := router.GetLegacyManagement() + + if oldLegacy != isLegacy { + router.SetLegacyManagement(isLegacy) + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to clean up the legacy rules + if !isLegacy && oldLegacy { + if err := router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + +// GenerateSetName generates a unique name for an ipset based on the given sources. +func GenerateSetName(sources []netip.Prefix) string { + // sort for consistent naming + sortPrefixes(sources) + + var sourcesStr strings.Builder + for _, src := range sources { + sourcesStr.WriteString(src.String()) + } + + hash := sha256.Sum256([]byte(sourcesStr.String())) + shortHash := hex.EncodeToString(hash[:])[:8] + + return fmt.Sprintf("nb-%s", shortHash) +} + +// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix +func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { + if len(prefixes) == 0 { + return prefixes + } + + merged := []netip.Prefix{prefixes[0]} + for _, prefix := range prefixes[1:] { + last := merged[len(merged)-1] + if last.Contains(prefix.Addr()) { + // If the current prefix is contained within the last merged prefix, skip it + continue + } + if prefix.Contains(last.Addr()) { + // If the current prefix contains the last merged prefix, replace it + merged[len(merged)-1] = prefix + } else { + // Otherwise, add the current prefix to the merged list + merged = append(merged, prefix) + } + } + + return merged +} + +// sortPrefixes sorts the given slice of netip.Prefix in place. +// It sorts first by IP address, then by prefix length (most specific to least specific). +func sortPrefixes(prefixes []netip.Prefix) { + sort.Slice(prefixes, func(i, j int) bool { + addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr()) + if addrCmp != 0 { + return addrCmp < 0 + } + + // If IP addresses are the same, compare prefix lengths (longer prefixes first) + return prefixes[i].Bits() > prefixes[j].Bits() + }) } diff --git a/client/firewall/manager/firewall_test.go b/client/firewall/manager/firewall_test.go new file mode 100644 index 000000000..3f47d6679 --- /dev/null +++ b/client/firewall/manager/firewall_test.go @@ -0,0 +1,192 @@ +package manager_test + +import ( + "net/netip" + "reflect" + "regexp" + "testing" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +func TestGenerateSetName(t *testing.T) { + t.Run("Different orders result in same hash", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) + } + }) + + t.Run("Result format is correct", func(t *testing.T) { + prefixes := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + } + + result := manager.GenerateSetName(prefixes) + + matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) + if err != nil { + t.Fatalf("Error matching regex: %v", err) + } + if !matched { + t.Errorf("Result format is incorrect: %s", result) + } + }) + + t.Run("Empty input produces consistent result", func(t *testing.T) { + result1 := manager.GenerateSetName([]netip.Prefix{}) + result2 := manager.GenerateSetName([]netip.Prefix{}) + + if result1 != result2 { + t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) + } + }) + + t.Run("IPv4 and IPv6 mixing", func(t *testing.T) { + prefixes1 := []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("2001:db8::/32"), + } + prefixes2 := []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("192.168.1.0/24"), + } + + result1 := manager.GenerateSetName(prefixes1) + result2 := manager.GenerateSetName(prefixes2) + + if result1 != result2 { + t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) + } + }) +} + +func TestMergeIPRanges(t *testing.T) { + tests := []struct { + name string + input []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Empty input", + input: []netip.Prefix{}, + expected: []netip.Prefix{}, + }, + { + name: "Single range", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Two non-overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + }, + { + name: "One range containing another", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "One range containing another (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Overlapping ranges (different order)", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.128/25"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "Multiple overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + netip.MustParsePrefix("192.168.1.128/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "Partially overlapping ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/23"), + netip.MustParsePrefix("192.168.2.0/25"), + }, + }, + { + name: "IPv6 ranges", + input: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + netip.MustParsePrefix("2001:db8:1::/48"), + netip.MustParsePrefix("2001:db8:2::/48"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("2001:db8::/32"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := manager.MergeIPRanges(tt.input) + if !reflect.DeepEqual(result, tt.expected) { + t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected) + } + }) + } +} diff --git a/client/firewall/manager/routerpair.go b/client/firewall/manager/routerpair.go index b63a9f104..8c94b7dd4 100644 --- a/client/firewall/manager/routerpair.go +++ b/client/firewall/manager/routerpair.go @@ -1,18 +1,26 @@ package manager +import ( + "net/netip" + + "github.com/netbirdio/netbird/route" +) + type RouterPair struct { - ID string - Source string - Destination string + ID route.ID + Source netip.Prefix + Destination netip.Prefix Masquerade bool + Inverse bool } -func GetInPair(pair RouterPair) RouterPair { +func GetInversePair(pair RouterPair) RouterPair { return RouterPair{ ID: pair.ID, // invert Source/Destination Source: pair.Destination, Destination: pair.Source, Masquerade: pair.Masquerade, + Inverse: true, } } diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 1fa41b63a..eaf7fb6a0 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -16,7 +16,7 @@ import ( "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -33,9 +33,10 @@ const ( allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) +const flushError = "flush: %w" + var ( - anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} - postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00} + anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} ) type AclManager struct { @@ -48,7 +49,6 @@ type AclManager struct { chainInputRules *nftables.Chain chainOutputRules *nftables.Chain chainFwFilter *nftables.Chain - chainPrerouting *nftables.Chain ipsetStore *ipsetStore rules map[string]*Rule @@ -64,7 +64,7 @@ type iFaceMapper interface { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) { // sConn is used for creating sets and adding/removing elements from them // it's differ then rConn (which does create new conn for each flush operation) - // and is permanent. Using same connection for booth type of operations + // and is permanent. Using same connection for both type of operations // overloads netlink with high amount of rules ( > 10000) sConn, err := nftables.New(nftables.AsLasting()) if err != nil { @@ -90,11 +90,11 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainNa return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *AclManager) AddFiltering( +func (m *AclManager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -120,20 +120,11 @@ func (m *AclManager) AddFiltering( } newRules = append(newRules, ioRule) - if !shouldAddToPrerouting(proto, dPort, direction) { - return newRules, nil - } - - preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip) - if err != nil { - return newRules, err - } - newRules = append(newRules, preroutingRule) return newRules, nil } -// DeleteRule from the firewall by rule definition -func (m *AclManager) DeleteRule(rule firewall.Rule) error { +// DeletePeerRule from the firewall by rule definition +func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { r, ok := rule.(*Rule) if !ok { return fmt.Errorf("invalid rule type") @@ -199,8 +190,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error { return nil } -// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for -// input and output chains +// createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ &expr.Payload{ @@ -214,13 +204,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -246,13 +236,13 @@ func (m *AclManager) createDefaultAllowRules() error { SourceRegister: 1, DestRegister: 1, Len: 4, - Mask: []byte{0x00, 0x00, 0x00, 0x00}, - Xor: zeroXor, + Mask: []byte{0, 0, 0, 0}, + Xor: []byte{0, 0, 0, 0}, }, // net address &expr.Cmp{ Register: 1, - Data: []byte{0x00, 0x00, 0x00, 0x00}, + Data: []byte{0, 0, 0, 0}, }, &expr.Verdict{ Kind: expr.VerdictAccept, @@ -266,10 +256,8 @@ func (m *AclManager) createDefaultAllowRules() error { Exprs: expOut, }) - err := m.rConn.Flush() - if err != nil { - log.Debugf("failed to create default allow rules: %s", err) - return err + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf(flushError, err) } return nil } @@ -290,15 +278,11 @@ func (m *AclManager) Flush() error { log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) } - if err := m.refreshRuleHandles(m.chainPrerouting); err != nil { - log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err) - } - return nil } func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { - ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset) + ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ r.nftRule, @@ -308,18 +292,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f }, nil } - ifaceKey := expr.MetaKeyIIFNAME - if direction == firewall.RuleDirectionOUT { - ifaceKey = expr.MetaKeyOIFNAME - } - expressions := []expr.Any{ - &expr.Meta{Key: ifaceKey, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - } + var expressions []expr.Any if proto != firewall.ProtocolALL { expressions = append(expressions, &expr.Payload{ @@ -329,21 +302,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f Len: uint32(1), }) - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) + protoData, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %v", err) } + expressions = append(expressions, &expr.Cmp{ Register: 1, Op: expr.CmpOpEq, - Data: protoData, + Data: []byte{protoData}, }) } @@ -432,10 +399,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f } else { chain = m.chainOutputRules } - nftRule := m.rConn.InsertRule(&nftables.Rule{ + nftRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: chain, - Position: 0, Exprs: expressions, UserData: userData, }) @@ -453,139 +419,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f return rule, nil } -func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) { - var protoData []byte - switch proto { - case firewall.ProtocolTCP: - protoData = []byte{unix.IPPROTO_TCP} - case firewall.ProtocolUDP: - protoData = []byte{unix.IPPROTO_UDP} - case firewall.ProtocolICMP: - protoData = []byte{unix.IPPROTO_ICMP} - default: - return nil, fmt.Errorf("unsupported protocol: %s", proto) - } - - ruleId := generateRuleIdForMangle(ipset, ip, proto, port) - if r, ok := m.rules[ruleId]; ok { - return &Rule{ - r.nftRule, - r.nftSet, - r.ruleID, - ip, - }, nil - } - - var ipExpression expr.Any - // add individual IP for match if no ipset defined - rawIP := ip.To4() - if ipset == nil { - ipExpression = &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: rawIP, - } - } else { - ipExpression = &expr.Lookup{ - SourceRegister: 1, - SetName: ipset.Name, - SetID: ipset.ID, - } - } - - expressions := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - ipExpression, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: uint32(9), - Len: uint32(1), - }, - &expr.Cmp{ - Register: 1, - Op: expr.CmpOpEq, - Data: protoData, - }, - } - - if port != nil { - expressions = append(expressions, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseTransportHeader, - Offset: 2, - Len: 2, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: encodePort(*port), - }, - ) - } - - expressions = append(expressions, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - ) - - nftRule := m.rConn.InsertRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainPrerouting, - Position: 0, - Exprs: expressions, - UserData: []byte(ruleId), - }) - - if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf("flush insert rule: %v", err) - } - - rule := &Rule{ - nftRule: nftRule, - nftSet: ipset, - ruleID: ruleId, - ip: ip, - } - - m.rules[ruleId] = rule - if ipset != nil { - m.ipsetStore.AddReferenceToIpset(ipset.Name) - } - return rule, nil -} - func (m *AclManager) createDefaultChains() (err error) { // chainNameInputRules chain := m.createChain(chainNameInputRules) err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chain.Name, err) - return err + return fmt.Errorf(flushError, err) } m.chainInputRules = chain @@ -601,9 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-input-filter // type filter hook input priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) - //netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept - m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME) - m.addFwdAllow(chain, expr.MetaKeyIIFNAME) m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addDropExpressions(chain, expr.MetaKeyIIFNAME) err = m.rConn.Flush() @@ -615,7 +452,6 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-output-filter // type filter hook output priority filter; policy accept; chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addDropExpressions(chain, expr.MetaKeyOIFNAME) @@ -627,24 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) { // netbird-acl-forward-filter m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) - m.addJumpRulesToRtForward() // to - m.addMarkAccept() - m.addJumpRuleToInputChain() // to netbird-acl-input-rules + m.addJumpRulesToRtForward() // to netbird-rt-fwd m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME) + err = m.rConn.Flush() if err != nil { log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) - return err + return fmt.Errorf(flushError, err) } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - m.chainPrerouting = m.createPreroutingMangle() - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err) - return err - } return nil } @@ -667,59 +494,6 @@ func (m *AclManager) addJumpRulesToRtForward() { Chain: m.chainFwFilter, Exprs: expressions, }) - - expressions = []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.routeingFwChainName, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addMarkAccept() { - // oifname "wt0" meta mark 0x000007e4 accept - // iifname "wt0" meta mark 0x000007e4 accept - ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME} - for _, iface := range ifaces { - expressions := []expr.Any{ - &expr.Meta{Key: iface, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: postroutingMark, - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) - } } func (m *AclManager) createChain(name string) *nftables.Chain { @@ -729,6 +503,9 @@ func (m *AclManager) createChain(name string) *nftables.Chain { } chain = m.rConn.AddChain(chain) + + insertReturnTrafficRule(m.rConn, m.workTable, chain) + return chain } @@ -746,74 +523,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha return m.rConn.AddChain(chain) } -func (m *AclManager) createPreroutingMangle() *nftables.Chain { - polAccept := nftables.ChainPolicyAccept - chain := &nftables.Chain{ - Name: "netbird-acl-prerouting-filter", - Table: m.workTable, - Hooknum: nftables.ChainHookPrerouting, - Priority: nftables.ChainPriorityMangle, - Type: nftables.ChainTypeFilter, - Policy: &polAccept, - } - - chain = m.rConn.AddChain(chain) - - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: m.wgIface.Address().IP.To4(), - }, - &expr.Immediate{ - Register: 1, - Data: postroutingMark, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - SourceRegister: true, - Register: 1, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: chain, - Exprs: expressions, - }) - return chain -} - func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, @@ -832,101 +541,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addJumpRuleToInputChain() { - expressions := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictJump, - Chain: m.chainInputRules.Name, - }, - } - - _ = m.rConn.AddRule(&nftables.Rule{ - Table: m.workTable, - Chain: m.chainFwFilter, - Exprs: expressions, - }) -} - -func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if netIfName == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } else { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } - expressions := []expr.Any{ - &expr.Meta{Key: netIfName, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - var srcOp, dstOp expr.CmpOp - if iifname == expr.MetaKeyIIFNAME { - srcOp = expr.CmpOpNeq - dstOp = expr.CmpOpEq - } else { - srcOp = expr.CmpOpEq - dstOp = expr.CmpOpNeq - } + dstOp := expr.CmpOpNeq expressions := []expr.Any{ &expr.Meta{Key: iifname, Register: 1}, &expr.Cmp{ @@ -934,24 +551,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: srcOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Payload{ DestRegister: 2, Base: expr.PayloadBaseNetworkHeader, @@ -982,7 +581,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { } func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, &expr.Cmp{ @@ -990,47 +588,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Payload{ - DestRegister: 2, - Base: expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, &expr.Verdict{ Kind: expr.VerdictJump, Chain: to, }, } + _ = m.rConn.AddRule(&nftables.Rule{ Table: chain.Table, Chain: chain, @@ -1132,7 +695,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error { return nil } -func generateRuleId( +func generatePeerRuleId( ip net.IP, sPort *firewall.Port, dPort *firewall.Port, @@ -1155,33 +718,6 @@ func generateRuleId( } return "set:" + ipset.Name + rulesetID } -func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string { - // case of icmp port is empty - var p string - if port != nil { - p = port.String() - } - if ipset != nil { - return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p) - } else { - return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p) - } -} - -func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool { - if proto == "all" { - return false - } - - if direction != firewall.RuleDirectionIN { - return false - } - - if dPort == nil && proto != firewall.ProtocolICMP { - return false - } - return true -} func encodePort(port firewall.Port) []byte { bs := make([]byte, 2) @@ -1191,6 +727,19 @@ func encodePort(port firewall.Port) []byte { func ifname(n string) []byte { b := make([]byte, 16) - copy(b, []byte(n+"\x00")) + copy(b, n+"\x00") return b } + +func protoToInt(protocol firewall.Protocol) (uint8, error) { + switch protocol { + case firewall.ProtocolTCP: + return unix.IPPROTO_TCP, nil + case firewall.ProtocolUDP: + return unix.IPPROTO_UDP, nil + case firewall.ProtocolICMP: + return unix.IPPROTO_ICMP, nil + } + + return 0, fmt.Errorf("unsupported protocol: %s", protocol) +} diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a376c98c3..d2258ae08 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -5,9 +5,11 @@ import ( "context" "fmt" "net" + "net/netip" "sync" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" @@ -15,8 +17,11 @@ import ( ) const ( - // tableName is the name of the table that is used for filtering by the Netbird client - tableName = "netbird" + // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + tableNameNetbird = "netbird" + + tableNameFilter = "filter" + chainNameInput = "INPUT" ) // Manager of iptables firewall @@ -41,12 +46,12 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return nil, err } - m.router, err = newRouter(context, workTable) + m.router, err = newRouter(context, workTable, wgIface) if err != nil { return nil, err } - m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName()) + m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) if err != nil { return nil, err } @@ -54,11 +59,11 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { return m, nil } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -76,33 +81,52 @@ func (m *Manager) AddFiltering( return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) } - return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) + return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclManager.DeleteRule(rule) + if !destination.Addr().Is4() { + return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) + } + + return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.aclManager.DeletePeerRule(rule) +} + +// DeleteRouteRule deletes a routing rule +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.DeleteRouteRule(rule) } func (m *Manager) IsServerRouteSupported() bool { return true } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.AddRoutingRules(pair) + return m.router.AddNatRule(pair) } -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { m.mutex.Lock() defer m.mutex.Unlock() - return m.router.RemoveRoutingRules(pair) + return m.router.RemoveNatRule(pair) } // AllowNetbird allows netbird interface traffic @@ -126,7 +150,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameForward { chain = c break } @@ -157,6 +181,27 @@ func (m *Manager) AllowNetbird() error { return nil } +// SetLegacyManagement sets the route manager to use legacy management +func (m *Manager) SetLegacyManagement(isLegacy bool) error { + oldLegacy := m.router.legacyManagement + + if oldLegacy != isLegacy { + m.router.legacyManagement = isLegacy + log.Debugf("Set legacy management to %v", isLegacy) + } + + // client reconnected to a newer mgmt, we need to cleanup the legacy rules + if !isLegacy && oldLegacy { + if err := m.router.RemoveAllLegacyRouteRules(); err != nil { + return fmt.Errorf("remove legacy routing rules: %v", err) + } + + log.Debugf("Legacy routing rules removed") + } + + return nil +} + // Reset firewall to the default state func (m *Manager) Reset() error { m.mutex.Lock() @@ -185,14 +230,16 @@ func (m *Manager) Reset() error { } } - m.router.ResetForwardRules() + if err := m.router.Reset(); err != nil { + return fmt.Errorf("reset forward rules: %v", err) + } tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list of tables: %w", err) } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } @@ -218,12 +265,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } @@ -239,9 +286,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, + &expr.Verdict{}, }, UserData: []byte(allowNetbirdInputRuleID), } @@ -251,7 +296,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { ifName := ifname(m.wgIface.Name()) for _, rule := range existedRules { - if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { + if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { if len(rule.Exprs) < 4 { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { continue @@ -265,3 +310,33 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable } return nil } + +func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } + + conn.InsertRule(rule) +} diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 1f226e315..904050a51 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -9,14 +9,30 @@ import ( "time" "github.com/google/nftables" + "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" "github.com/stretchr/testify/require" "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) +var ifaceMock = &iFaceMock{ + NameFunc: func() string { + return "lo" + }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.96.0.1"), + Network: &net.IPNet{ + IP: net.ParseIP("100.96.0.0"), + Mask: net.IPv4Mask(255, 255, 255, 0), + }, + } + }, +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMock struct { NameFunc func() string @@ -40,23 +56,9 @@ func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { - mock := &iFaceMock{ - NameFunc: func() string { - return "lo" - }, - AddressFunc: func() iface.WGAddress { - return iface.WGAddress{ - IP: net.ParseIP("100.96.0.1"), - Network: &net.IPNet{ - IP: net.ParseIP("100.96.0.0"), - Mask: net.IPv4Mask(255, 255, 255, 0), - }, - } - }, - } // just check on the local interface - manager, err := Create(context.Background(), mock) + manager, err := Create(context.Background(), ifaceMock) require.NoError(t, err) time.Sleep(time.Second * 3) @@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) { testClient := &nftables.Conn{} - rule, err := manager.AddFiltering( + rule, err := manager.AddPeerFiltering( ip, fw.ProtocolTCP, nil, @@ -88,17 +90,34 @@ func TestNftablesManager(t *testing.T) { rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 1, "expected 1 rules") + require.Len(t, rules, 2, "expected 2 rules") + + expectedExprs1 := []expr.Any{ + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 1, + }, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() - expectedExprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname("lo"), - }, + expectedExprs2 := []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, @@ -134,10 +153,10 @@ func TestNftablesManager(t *testing.T) { }, &expr.Verdict{Kind: expr.VerdictDrop}, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") + require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions") for _, r := range rule { - err = manager.DeleteRule(r) + err = manager.DeletePeerRule(r) require.NoError(t, err, "failed to delete rule") } @@ -146,7 +165,8 @@ func TestNftablesManager(t *testing.T) { rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) require.NoError(t, err, "failed to get rules") - require.Len(t, rules, 0, "expected 0 rules after deletion") + // established rule remains + require.Len(t, rules, 1, "expected 1 rules after deletion") err = manager.Reset() require.NoError(t, err, "failed to reset") @@ -187,9 +207,9 @@ func TestNFtablesCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/client/firewall/nftables/route_linux.go b/client/firewall/nftables/route_linux.go deleted file mode 100644 index 71d5ac88e..000000000 --- a/client/firewall/nftables/route_linux.go +++ /dev/null @@ -1,431 +0,0 @@ -package nftables - -import ( - "bytes" - "context" - "errors" - "fmt" - "net" - "net/netip" - - "github.com/google/nftables" - "github.com/google/nftables/binaryutil" - "github.com/google/nftables/expr" - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/client/firewall/manager" -) - -const ( - chainNameRouteingFw = "netbird-rt-fwd" - chainNameRoutingNat = "netbird-rt-nat" - - userDataAcceptForwardRuleSrc = "frwacceptsrc" - userDataAcceptForwardRuleDst = "frwacceptdst" - - loopbackInterface = "lo\x00" -) - -// some presets for building nftable rules -var ( - zeroXor = binaryutil.NativeEndian.PutUint32(0) - - exprCounterAccept = []expr.Any{ - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - - errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") -) - -type router struct { - ctx context.Context - stop context.CancelFunc - conn *nftables.Conn - workTable *nftables.Table - filterTable *nftables.Table - chains map[string]*nftables.Chain - // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules - rules map[string]*nftables.Rule - isDefaultFwdRulesEnabled bool -} - -func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) { - ctx, cancel := context.WithCancel(parentCtx) - - r := &router{ - ctx: ctx, - stop: cancel, - conn: &nftables.Conn{}, - workTable: workTable, - chains: make(map[string]*nftables.Chain), - rules: make(map[string]*nftables.Rule), - } - - var err error - r.filterTable, err = r.loadFilterTable() - if err != nil { - if errors.Is(err, errFilterTableNotFound) { - log.Warnf("table 'filter' not found for forward rules") - } else { - return nil, err - } - } - - err = r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.createContainers() - if err != nil { - log.Errorf("failed to create containers for route: %s", err) - } - return r, err -} - -func (r *router) RouteingFwChainName() string { - return chainNameRouteingFw -} - -// ResetForwardRules cleans existing nftables default forward rules from the system -func (r *router) ResetForwardRules() { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to reset forward rules: %s", err) - } -} - -func (r *router) loadFilterTable() (*nftables.Table, error) { - tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) - if err != nil { - return nil, fmt.Errorf("nftables: unable to list tables: %v", err) - } - - for _, table := range tables { - if table.Name == "filter" { - return table, nil - } - } - - return nil, errFilterTableNotFound -} - -func (r *router) createContainers() error { - - r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRouteingFw, - Table: r.workTable, - }) - - r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ - Name: chainNameRoutingNat, - Table: r.workTable, - Hooknum: nftables.ChainHookPostrouting, - Priority: nftables.ChainPriorityNATSource - 1, - Type: nftables.ChainTypeNAT, - }) - - // Add RETURN rule for loopback interface - loRule := &nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainNameRoutingNat], - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: []byte(loopbackInterface), - }, - &expr.Verdict{Kind: expr.VerdictReturn}, - }, - } - r.conn.InsertRule(loRule) - - err := r.refreshRulesMap() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to initialize table: %v", err) - } - return nil -} - -// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain -func (r *router) AddRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false) - if err != nil { - return err - } - - if pair.Masquerade { - err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true) - if err != nil { - return err - } - err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true) - if err != nil { - return err - } - } - - if r.filterTable != nil && !r.isDefaultFwdRulesEnabled { - log.Debugf("add default accept forward rule") - r.acceptForwardRule(pair.Source) - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err) - } - return nil -} - -// addRoutingRule inserts a nftable rule to the conn client flush queue -func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error { - sourceExp := generateCIDRMatcherExpressions(true, pair.Source) - destExp := generateCIDRMatcherExpressions(false, pair.Destination) - - var expression []expr.Any - if isNat { - expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic - } else { - expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic - } - - ruleKey := manager.GenKey(format, pair.ID) - - _, exists := r.rules[ruleKey] - if exists { - err := r.removeRoutingRule(format, pair) - if err != nil { - return err - } - } - - r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ - Table: r.workTable, - Chain: r.chains[chainName], - Exprs: expression, - UserData: []byte(ruleKey), - }) - return nil -} - -func (r *router) acceptForwardRule(sourceNetwork string) { - src := generateCIDRMatcherExpressions(true, sourceNetwork) - dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0") - - var exprs []expr.Any - exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule := &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleSrc), - } - - r.conn.AddRule(rule) - - src = generateCIDRMatcherExpressions(true, "0.0.0.0/0") - dst = generateCIDRMatcherExpressions(false, sourceNetwork) - - exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic - Kind: expr.VerdictAccept, - })...) - - rule = &nftables.Rule{ - Table: r.filterTable, - Chain: &nftables.Chain{ - Name: "FORWARD", - Table: r.filterTable, - Type: nftables.ChainTypeFilter, - Hooknum: nftables.ChainHookForward, - Priority: nftables.ChainPriorityFilter, - }, - Exprs: exprs, - UserData: []byte(userDataAcceptForwardRuleDst), - } - r.conn.AddRule(rule) - r.isDefaultFwdRulesEnabled = true -} - -// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains -func (r *router) RemoveRoutingRules(pair manager.RouterPair) error { - err := r.refreshRulesMap() - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.ForwardingFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.NatFormat, pair) - if err != nil { - return err - } - - err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair)) - if err != nil { - return err - } - - if len(r.rules) == 0 { - err := r.cleanUpDefaultForwardRules() - if err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) - } - } - - err = r.conn.Flush() - if err != nil { - return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) - } - log.Debugf("nftables: removed rules for %s", pair.Destination) - return nil -} - -// removeRoutingRule add a nftable rule to the removal queue and delete from rules map -func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error { - ruleKey := manager.GenKey(format, pair.ID) - - rule, found := r.rules[ruleKey] - if found { - ruleType := "forwarding" - if rule.Chain.Type == nftables.ChainTypeNAT { - ruleType = "nat" - } - - err := r.conn.DelRule(rule) - if err != nil { - return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err) - } - - log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination) - - delete(r.rules, ruleKey) - } - return nil -} - -// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid -// duplicates and to get missing attributes that we don't have when adding new rules -func (r *router) refreshRulesMap() error { - for _, chain := range r.chains { - rules, err := r.conn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("nftables: unable to list rules: %v", err) - } - for _, rule := range rules { - if len(rule.UserData) > 0 { - r.rules[string(rule.UserData)] = rule - } - } - } - return nil -} - -func (r *router) cleanUpDefaultForwardRules() error { - if r.filterTable == nil { - r.isDefaultFwdRulesEnabled = false - return nil - } - - chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return err - } - - var rules []*nftables.Rule - for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name { - continue - } - if chain.Name != "FORWARD" { - continue - } - - rules, err = r.conn.GetRules(r.filterTable, chain) - if err != nil { - return err - } - } - - for _, rule := range rules { - if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) { - err := r.conn.DelRule(rule) - if err != nil { - return err - } - } - } - r.isDefaultFwdRulesEnabled = false - return r.conn.Flush() -} - -// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR -func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any { - ip, network, _ := net.ParseCIDR(cidr) - ipToAdd, _ := netip.AddrFromSlice(ip) - add := ipToAdd.Unmap() - - var offSet uint32 - if source { - offSet = 12 // src offset - } else { - offSet = 16 // dst offset - } - - return []expr.Any{ - // fetch src add - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: offSet, - Len: 4, - }, - // net mask - &expr.Bitwise{ - DestRegister: 1, - SourceRegister: 1, - Len: 4, - Mask: network.Mask, - Xor: zeroXor, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: add.AsSlice(), - }, - } -} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go new file mode 100644 index 000000000..aa61e1858 --- /dev/null +++ b/client/firewall/nftables/router_linux.go @@ -0,0 +1,798 @@ +package nftables + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "net" + "net/netip" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" +) + +const ( + chainNameRoutingFw = "netbird-rt-fwd" + chainNameRoutingNat = "netbird-rt-nat" + chainNameForward = "FORWARD" + + userDataAcceptForwardRuleIif = "frwacceptiif" + userDataAcceptForwardRuleOif = "frwacceptoif" +) + +const refreshRulesMapError = "refresh rules map: %w" + +var ( + errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") +) + +type router struct { + ctx context.Context + stop context.CancelFunc + conn *nftables.Conn + workTable *nftables.Table + filterTable *nftables.Table + chains map[string]*nftables.Chain + // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules + rules map[string]*nftables.Rule + ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] + + wgIface iFaceMapper + legacyManagement bool +} + +func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { + ctx, cancel := context.WithCancel(parentCtx) + + r := &router{ + ctx: ctx, + stop: cancel, + conn: &nftables.Conn{}, + workTable: workTable, + chains: make(map[string]*nftables.Chain), + rules: make(map[string]*nftables.Rule), + wgIface: wgIface, + } + + r.ipsetCounter = refcounter.New( + r.createIpSet, + r.deleteIpSet, + ) + + var err error + r.filterTable, err = r.loadFilterTable() + if err != nil { + if errors.Is(err, errFilterTableNotFound) { + log.Warnf("table 'filter' not found for forward rules") + } else { + return nil, err + } + } + + err = r.cleanUpDefaultForwardRules() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.createContainers() + if err != nil { + log.Errorf("failed to create containers for route: %s", err) + } + return r, err +} + +// Reset cleans existing nftables default forward rules from the system +func (r *router) Reset() error { + // clear without deleting the ipsets, the nf table will be deleted by the caller + r.ipsetCounter.Clear() + + return r.cleanUpDefaultForwardRules() +} + +func (r *router) cleanUpDefaultForwardRules() error { + if r.filterTable == nil { + return nil + } + + chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %v", err) + } + + for _, chain := range chains { + if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + continue + } + + rules, err := r.conn.GetRules(r.filterTable, chain) + if err != nil { + return fmt.Errorf("get rules: %v", err) + } + + for _, rule := range rules { + if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule: %v", err) + } + } + } + } + + return r.conn.Flush() +} + +func (r *router) loadFilterTable() (*nftables.Table, error) { + tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, fmt.Errorf("nftables: unable to list tables: %v", err) + } + + for _, table := range tables { + if table.Name == "filter" { + return table, nil + } + } + + return nil, errFilterTableNotFound +} + +func (r *router) createContainers() error { + + r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingFw, + Table: r.workTable, + }) + + insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) + + r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameRoutingNat, + Table: r.workTable, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource - 1, + Type: nftables.ChainTypeNAT, + }) + + r.acceptForwardRules() + + err := r.refreshRulesMap() + if err != nil { + log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + } + + err = r.conn.Flush() + if err != nil { + return fmt.Errorf("nftables: unable to initialize table: %v", err) + } + return nil +} + +// AddRouteFiltering appends a nftables rule to the routing chain +func (r *router) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) + if _, ok := r.rules[string(ruleKey)]; ok { + return ruleKey, nil + } + + chain := r.chains[chainNameRoutingFw] + var exprs []expr.Any + + switch { + case len(sources) == 1 && sources[0].Bits() == 0: + // If it's 0.0.0.0/0, we don't need to add any source matching + case len(sources) == 1: + // If there's only one source, we can use it directly + exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) + default: + // If there are multiple sources, create or get an ipset + var err error + exprs, err = r.getIpSetExprs(sources, exprs) + if err != nil { + return nil, fmt.Errorf("get ipset expressions: %w", err) + } + } + + // Handle destination + exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) + + // Handle protocol + if proto != firewall.ProtocolALL { + protoNum, err := protoToInt(proto) + if err != nil { + return nil, fmt.Errorf("convert protocol to number: %w", err) + } + exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}) + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{protoNum}, + }) + + exprs = append(exprs, applyPort(sPort, true)...) + exprs = append(exprs, applyPort(dPort, false)...) + } + + exprs = append(exprs, &expr.Counter{}) + + var verdict expr.VerdictKind + if action == firewall.ActionAccept { + verdict = expr.VerdictAccept + } else { + verdict = expr.VerdictDrop + } + exprs = append(exprs, &expr.Verdict{Kind: verdict}) + + rule := &nftables.Rule{ + Table: r.workTable, + Chain: chain, + Exprs: exprs, + UserData: []byte(ruleKey), + } + + r.rules[string(ruleKey)] = r.conn.AddRule(rule) + + return ruleKey, r.conn.Flush() +} + +func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { + setName := firewall.GenerateSetName(sources) + ref, err := r.ipsetCounter.Increment(setName, sources) + if err != nil { + return nil, fmt.Errorf("create or get ipset for sources: %w", err) + } + + exprs = append(exprs, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, + &expr.Lookup{ + SourceRegister: 1, + SetName: ref.Out.Name, + SetID: ref.Out.ID, + }, + ) + return exprs, nil +} + +func (r *router) DeleteRouteRule(rule firewall.Rule) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleKey := rule.GetRuleID() + nftRule, exists := r.rules[ruleKey] + if !exists { + log.Debugf("route rule %s not found", ruleKey) + return nil + } + + setName := r.findSetNameInRule(nftRule) + + if err := r.deleteNftRule(nftRule, ruleKey); err != nil { + return fmt.Errorf("delete: %w", err) + } + + if setName != "" { + if _, err := r.ipsetCounter.Decrement(setName); err != nil { + return fmt.Errorf("decrement ipset reference: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + return nil +} + +func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { + // overlapping prefixes will result in an error, so we need to merge them + sources = firewall.MergeIPRanges(sources) + + set := &nftables.Set{ + Name: setName, + Table: r.workTable, + // required for prefixes + Interval: true, + KeyType: nftables.TypeIPAddr, + } + + var elements []nftables.SetElement + for _, prefix := range sources { + // TODO: Implement IPv6 support + if prefix.Addr().Is6() { + log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) + continue + } + + // nftables needs half-open intervals [firstIP, lastIP) for prefixes + // e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc + firstIP := prefix.Addr() + lastIP := calculateLastIP(prefix).Next() + + elements = append(elements, + // the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247 + // nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true}, + nftables.SetElement{Key: firstIP.AsSlice()}, + nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, + ) + } + + if err := r.conn.AddSet(set, elements); err != nil { + return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err) + } + + if err := r.conn.Flush(); err != nil { + return nil, fmt.Errorf("flush error: %w", err) + } + + log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2) + + return set, nil +} + +// calculateLastIP determines the last IP in a given prefix. +func calculateLastIP(prefix netip.Prefix) netip.Addr { + hostMask := ^uint32(0) >> prefix.Masked().Bits() + lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask + + return netip.AddrFrom4(uint32ToBytes(lastIP)) +} + +// Utility function to convert netip.Addr to uint32. +func uint32FromNetipAddr(addr netip.Addr) uint32 { + b := addr.As4() + return binary.BigEndian.Uint32(b[:]) +} + +// Utility function to convert uint32 to a netip-compatible byte slice. +func uint32ToBytes(ip uint32) [4]byte { + var b [4]byte + binary.BigEndian.PutUint32(b[:], ip) + return b +} + +func (r *router) deleteIpSet(setName string, set *nftables.Set) error { + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + return fmt.Errorf(flushError, err) + } + + log.Debugf("Deleted unused ipset %s", setName) + return nil +} + +func (r *router) findSetNameInRule(rule *nftables.Rule) string { + for _, e := range rule.Exprs { + if lookup, ok := e.(*expr.Lookup); ok { + return lookup.SetName + } + } + return "" +} + +func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete rule %s: %w", ruleKey, err) + } + delete(r.rules, ruleKey) + + log.Debugf("removed route rule %s", ruleKey) + + return nil +} + +// AddNatRule appends a nftables rule pair to the nat chain +func (r *router) AddNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if r.legacyManagement { + log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) + if err := r.addLegacyRouteRule(pair); err != nil { + return fmt.Errorf("add legacy routing rule: %w", err) + } + } + + if pair.Masquerade { + if err := r.addNatRule(pair); err != nil { + return fmt.Errorf("add nat rule: %w", err) + } + + if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("add inverse nat rule: %w", err) + } + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) + } + + return nil +} + +// addNatRule inserts a nftables rule to the conn client flush queue +func (r *router) addNatRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + dir := expr.MetaKeyIIFNAME + if pair.Inverse { + dir = expr.MetaKeyOIFNAME + } + + intf := ifname(r.wgIface.Name()) + exprs := []expr.Any{ + &expr.Meta{ + Key: dir, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + } + + exprs = append(exprs, sourceExp...) + exprs = append(exprs, destExp...) + exprs = append(exprs, + &expr.Counter{}, &expr.Masq{}, + ) + + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingNat], + Exprs: exprs, + UserData: []byte(ruleKey), + }) + return nil +} + +// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls +func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { + sourceExp := generateCIDRMatcherExpressions(true, pair.Source) + destExp := generateCIDRMatcherExpressions(false, pair.Destination) + + exprs := []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + } + + expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic + + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if _, exists := r.rules[ruleKey]; exists { + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + } + + r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingFw], + Exprs: expression, + UserData: []byte(ruleKey), + }) + return nil +} + +// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls +func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) + } + + return nil +} + +// GetLegacyManagement returns the route manager's legacy management mode +func (r *router) GetLegacyManagement() bool { + return r.legacyManagement +} + +// SetLegacyManagement sets the route manager to use legacy management mode +func (r *router) SetLegacyManagement(isLegacy bool) { + r.legacyManagement = isLegacy +} + +// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls +func (r *router) RemoveAllLegacyRouteRules() error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + var merr *multierror.Error + for k, rule := range r.rules { + if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { + continue + } + if err := r.conn.DelRule(rule); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) + } + } + return nberrors.FormatErrorOrNil(merr) +} + +// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure +// that our traffic is not dropped by existing rules there. +// The existing FORWARD rules/policies decide outbound traffic towards our interface. +// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +func (r *router) acceptForwardRules() { + if r.filterTable == nil { + log.Debugf("table 'filter' not found for forward rules, skipping accept rules") + return + } + + intf := ifname(r.wgIface.Name()) + + // Rule for incoming interface (iif) with counter + iifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleIif), + } + r.conn.InsertRule(iifRule) + + // Rule for outgoing interface (oif) with counter + oifRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: "FORWARD", + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Ct{ + Key: expr.CtKeySTATE, + Register: 2, + }, + &expr.Bitwise{ + SourceRegister: 2, + DestRegister: 2, + Len: 4, + Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED), + Xor: binaryutil.NativeEndian.PutUint32(0), + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 2, + Data: []byte{0, 0, 0, 0}, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptForwardRuleOif), + } + + r.conn.InsertRule(oifRule) +} + +// RemoveNatRule removes a nftables rule pair from nat chains +func (r *router) RemoveNatRule(pair firewall.RouterPair) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + if err := r.removeNatRule(pair); err != nil { + return fmt.Errorf("remove nat rule: %w", err) + } + + if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { + return fmt.Errorf("remove inverse nat rule: %w", err) + } + + if err := r.removeLegacyRouteRule(pair); err != nil { + return fmt.Errorf("remove legacy routing rule: %w", err) + } + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) + } + + log.Debugf("nftables: removed rules for %s", pair.Destination) + return nil +} + +// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map +func (r *router) removeNatRule(pair firewall.RouterPair) error { + ruleKey := firewall.GenKey(firewall.NatFormat, pair) + + if rule, exists := r.rules[ruleKey]; exists { + err := r.conn.DelRule(rule) + if err != nil { + return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) + } + + log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) + + delete(r.rules, ruleKey) + } else { + log.Debugf("nftables: nat rule %s not found", ruleKey) + } + + return nil +} + +// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid +// duplicates and to get missing attributes that we don't have when adding new rules +func (r *router) refreshRulesMap() error { + for _, chain := range r.chains { + rules, err := r.conn.GetRules(chain.Table, chain) + if err != nil { + return fmt.Errorf("nftables: unable to list rules: %v", err) + } + for _, rule := range rules { + if len(rule.UserData) > 0 { + r.rules[string(rule.UserData)] = rule + } + } + } + return nil +} + +// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR +func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { + var offset uint32 + if source { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + ones := prefix.Bits() + // 0.0.0.0/0 doesn't need extra expressions + if ones == 0 { + return nil + } + + mask := net.CIDRMask(ones, 32) + + return []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: offset, + Len: 4, + }, + // netmask + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 4, + Mask: mask, + Xor: []byte{0, 0, 0, 0}, + }, + // net address + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: prefix.Masked().Addr().AsSlice(), + }, + } +} + +func applyPort(port *firewall.Port, isSource bool) []expr.Any { + if port == nil { + return nil + } + + var exprs []expr.Any + + offset := uint32(2) // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + exprs = append(exprs, &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: offset, + Len: 2, + }) + + if port.IsRange && len(port.Values) == 2 { + // Handle port range + exprs = append(exprs, + &expr.Cmp{ + Op: expr.CmpOpGte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])), + }, + &expr.Cmp{ + Op: expr.CmpOpLte, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])), + }, + ) + } else { + // Handle single port or multiple ports + for i, p := range port.Values { + if i > 0 { + // Add a bitwise OR operation between port checks + exprs = append(exprs, &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: []byte{0x00, 0x00, 0xff, 0xff}, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }) + } + exprs = append(exprs, &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(p)), + }) + } + } + + return exprs +} diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 913fbd5d2..bbf92f3be 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -4,11 +4,15 @@ package nftables import ( "context" + "encoding/binary" + "net/netip" + "os/exec" "testing" "github.com/coreos/go-iptables/iptables" "github.com/google/nftables" "github.com/google/nftables/expr" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -24,56 +28,50 @@ const ( NFTABLES ) -func TestNftablesManager_InsertRoutingRules(t *testing.T) { +func TestNftablesManager_AddNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) require.NoError(t, err, "shouldn't return error") - err = manager.AddRoutingRules(testCase.InputPair) - defer func() { - _ = manager.RemoveRoutingRules(testCase.InputPair) - }() - require.NoError(t, err, "forwarding pair should be inserted") + err = manager.AddNatRule(testCase.InputPair) + require.NoError(t, err, "pair should be inserted") - sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) - destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - testingExpression := append(sourceExp, destExp...) //nolint:gocritic - fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - - found := 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") + defer func(manager *router, pair firewall.RouterPair) { + require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") + }(manager, testCase.InputPair) if testCase.InputPair.Masquerade { - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -88,27 +86,20 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { require.Equal(t, 1, found, "should find at least 1 rule to test") } - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - testingExpression = append(sourceExp, destExp...) //nolint:gocritic - inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - - found = 0 - for _, chain := range manager.chains { - rules, err := nftablesTestingClient.GetRules(chain.Table, chain) - require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) - for _, rule := range rules { - if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey { - require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match") - found = 1 - } - } - } - - require.Equal(t, 1, found, "should find at least 1 rule to test") - if testCase.InputPair.Masquerade { - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) + destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) + testingExpression := append(sourceExp, destExp...) //nolint:gocritic + testingExpression = append(testingExpression, + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(ifaceMock.Name()), + }, + ) + + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) found := 0 for _, chain := range manager.chains { rules, err := nftablesTestingClient.GetRules(chain.Table, chain) @@ -122,45 +113,37 @@ func TestNftablesManager_InsertRoutingRules(t *testing.T) { } require.Equal(t, 1, found, "should find at least 1 rule to test") } + }) } } -func TestNftablesManager_RemoveRoutingRules(t *testing.T) { +func TestNftablesManager_RemoveNatRule(t *testing.T) { if check() != NFTABLES { t.Skip("nftables not supported on this OS") } table, err := createWorkTable() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Failed to create work table") defer deleteWorkTable() for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := newRouter(context.TODO(), table) + manager, err := newRouter(context.TODO(), table, ifaceMock) require.NoError(t, err, "failed to create router") nftablesTestingClient := &nftables.Conn{} - defer manager.ResetForwardRules() + defer func(manager *router) { + require.NoError(t, manager.Reset(), "failed to reset rules") + }(manager) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) - forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) - insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(forwardRuleKey), - }) - natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) + natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -169,20 +152,11 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { UserData: []byte(natRuleKey), }) - sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) - destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) - - forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic - inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) - insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ - Table: manager.workTable, - Chain: manager.chains[chainNameRouteingFw], - Exprs: forwardExp, - UserData: []byte(inForwardRuleKey), - }) + sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) + destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination) natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic - inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) + inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ Table: manager.workTable, @@ -194,9 +168,10 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { err = nftablesTestingClient.Flush() require.NoError(t, err, "shouldn't return error") - manager.ResetForwardRules() + err = manager.Reset() + require.NoError(t, err, "shouldn't return error") - err = manager.RemoveRoutingRules(testCase.InputPair) + err = manager.RemoveNatRule(testCase.InputPair) require.NoError(t, err, "shouldn't return error") for _, chain := range manager.chains { @@ -204,9 +179,7 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) for _, rule := range rules { if len(rule.UserData) > 0 { - require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist") require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") - require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") } } @@ -215,6 +188,468 @@ func TestNftablesManager_RemoveRoutingRules(t *testing.T) { } } +func TestRouter_AddRouteFiltering(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func(r *router) { + require.NoError(t, r.Reset(), "Failed to reset rules") + }(r) + + tests := []struct { + name string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + sPort *firewall.Port + dPort *firewall.Port + direction firewall.RuleDirection + action firewall.Action + expectSet bool + }{ + { + name: "Basic TCP rule with single source", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolTCP, + sPort: nil, + dPort: &firewall.Port{Values: []int{80}}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with multiple sources", + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolUDP, + sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionDrop, + expectSet: true, + }, + { + name: "All protocols rule", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + destination: netip.MustParsePrefix("0.0.0.0/0"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "ICMP rule", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")}, + destination: netip.MustParsePrefix("10.0.0.0/8"), + proto: firewall.ProtocolICMP, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "TCP rule with multiple source ports", + sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")}, + destination: netip.MustParsePrefix("192.168.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{80, 443, 8080}}, + dPort: nil, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "UDP rule with single IP and port range", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + destination: netip.MustParsePrefix("10.0.0.0/24"), + proto: firewall.ProtocolUDP, + sPort: nil, + dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true}, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + { + name: "TCP rule with source and destination ports", + sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + destination: netip.MustParsePrefix("172.16.0.0/16"), + proto: firewall.ProtocolTCP, + sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true}, + dPort: &firewall.Port{Values: []int{22}}, + direction: firewall.RuleDirectionOUT, + action: firewall.ActionAccept, + expectSet: false, + }, + { + name: "Drop all incoming traffic", + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + destination: netip.MustParsePrefix("192.168.0.0/24"), + proto: firewall.ProtocolALL, + sPort: nil, + dPort: nil, + direction: firewall.RuleDirectionIN, + action: firewall.ActionDrop, + expectSet: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) + require.NoError(t, err, "AddRouteFiltering failed") + + // Check if the rule is in the internal map + rule, ok := r.rules[ruleKey.GetRuleID()] + assert.True(t, ok, "Rule not found in internal map") + + t.Log("Internal rule expressions:") + for i, expr := range rule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify internal rule content + verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + + // Check if the rule exists in nftables and verify its content + rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw]) + require.NoError(t, err, "Failed to get rules from nftables") + + var nftRule *nftables.Rule + for _, rule := range rules { + if string(rule.UserData) == ruleKey.GetRuleID() { + nftRule = rule + break + } + } + + require.NotNil(t, nftRule, "Rule not found in nftables") + t.Log("Actual nftables rule expressions:") + for i, expr := range nftRule.Exprs { + t.Logf(" [%d] %T: %+v", i, expr, expr) + } + + // Verify actual nftables rule content + verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet) + + // Clean up + err = r.DeleteRouteRule(ruleKey) + require.NoError(t, err, "Failed to delete rule") + }) + } +} + +func TestNftablesCreateIpSet(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + workTable, err := createWorkTable() + require.NoError(t, err, "Failed to create work table") + + defer deleteWorkTable() + + r, err := newRouter(context.Background(), workTable, ifaceMock) + require.NoError(t, err, "Failed to create router") + + defer func() { + require.NoError(t, r.Reset(), "Failed to reset router") + }() + + tests := []struct { + name string + sources []netip.Prefix + expected []netip.Prefix + }{ + { + name: "Single IP", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")}, + }, + { + name: "Multiple IPs", + sources: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("172.16.0.1/32"), + }, + }, + { + name: "Single Subnet", + sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + { + name: "Multiple Subnets with Various Prefix Lengths", + sources: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("203.0.113.0/26"), + }, + }, + { + name: "Mix of Single IPs and Subnets in Different Positions", + sources: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.1/32"), + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("172.16.0.1/32"), + netip.MustParsePrefix("203.0.113.0/24"), + }, + }, + { + name: "Overlapping IPs/Subnets", + sources: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("10.0.0.0/16"), + netip.MustParsePrefix("10.0.0.1/32"), + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.1.1/32"), + }, + expected: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + } + + // Add this helper function inside TestNftablesCreateIpSet + printNftSets := func() { + cmd := exec.Command("nft", "list", "sets") + output, err := cmd.CombinedOutput() + if err != nil { + t.Logf("Failed to run 'nft list sets': %v", err) + } else { + t.Logf("Current nft sets:\n%s", output) + } + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setName := firewall.GenerateSetName(tt.sources) + set, err := r.createIpSet(setName, tt.sources) + if err != nil { + t.Logf("Failed to create IP set: %v", err) + printNftSets() + require.NoError(t, err, "Failed to create IP set") + } + require.NotNil(t, set, "Created set is nil") + + // Verify set properties + assert.Equal(t, setName, set.Name, "Set name mismatch") + assert.Equal(t, r.workTable, set.Table, "Set table mismatch") + assert.True(t, set.Interval, "Set interval property should be true") + assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch") + + // Fetch the created set from nftables + fetchedSet, err := r.conn.GetSetByName(r.workTable, setName) + require.NoError(t, err, "Failed to fetch created set") + require.NotNil(t, fetchedSet, "Fetched set is nil") + + // Verify set elements + elements, err := r.conn.GetSetElements(fetchedSet) + require.NoError(t, err, "Failed to get set elements") + + // Count the number of unique prefixes (excluding interval end markers) + uniquePrefixes := make(map[string]bool) + for _, elem := range elements { + if !elem.IntervalEnd { + ip := netip.AddrFrom4(*(*[4]byte)(elem.Key)) + uniquePrefixes[ip.String()] = true + } + } + + // Check against expected merged prefixes + expectedCount := len(tt.expected) + if expectedCount == 0 { + expectedCount = len(tt.sources) + } + assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected") + + // Verify each expected prefix is in the set + for _, expected := range tt.expected { + found := false + for _, elem := range elements { + if !elem.IntervalEnd { + ip := netip.AddrFrom4(*(*[4]byte)(elem.Key)) + if expected.Contains(ip) { + found = true + break + } + } + } + assert.True(t, found, "Expected prefix %s not found in set", expected) + } + + r.conn.DelSet(set) + if err := r.conn.Flush(); err != nil { + t.Logf("Failed to delete set: %v", err) + printNftSets() + } + require.NoError(t, err, "Failed to delete set") + }) + } +} + +func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) { + t.Helper() + + assert.NotNil(t, rule, "Rule should not be nil") + + // Verify sources and destination + if expectSet { + assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources") + } else if len(sources) == 1 && sources[0].Bits() != 0 { + if direction == firewall.RuleDirectionIN { + assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0]) + } else { + assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0]) + } + } + + if direction == firewall.RuleDirectionIN { + assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination) + } else { + assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination) + } + + // Verify protocol + if proto != firewall.ProtocolALL { + assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto) + } + + // Verify ports + if sPort != nil { + assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort) + } + if dPort != nil { + assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort) + } + + // Verify action + assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action) +} + +func containsSetLookup(exprs []expr.Any) bool { + for _, e := range exprs { + if _, ok := e.(*expr.Lookup); ok { + return true + } + } + return false +} + +func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool { + var offset uint32 + if isSource { + offset = 12 // src offset + } else { + offset = 16 // dst offset + } + + var payloadFound, bitwiseFound, cmpFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 { + payloadFound = true + } + case *expr.Bitwise: + if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 { + bitwiseFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 { + cmpFound = true + } + } + } + return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0 +} + +func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool { + var offset uint32 = 2 // Default offset for destination port + if isSource { + offset = 0 // Offset for source port + } + + var payloadFound, portMatchFound bool + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Payload: + if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { + payloadFound = true + } + case *expr.Cmp: + if port.IsRange { + if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { + portMatchFound = true + } + } else { + if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { + portValue := binary.BigEndian.Uint16(ex.Data) + for _, p := range port.Values { + if uint16(p) == portValue { + portMatchFound = true + break + } + } + } + } + } + if payloadFound && portMatchFound { + return true + } + } + return false +} + +func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool { + var metaFound, cmpFound bool + expectedProto, _ := protoToInt(proto) + for _, e := range exprs { + switch ex := e.(type) { + case *expr.Meta: + if ex.Key == expr.MetaKeyL4PROTO { + metaFound = true + } + case *expr.Cmp: + if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto { + cmpFound = true + } + } + } + return metaFound && cmpFound +} + +func containsAction(exprs []expr.Any, action firewall.Action) bool { + for _, e := range exprs { + if verdict, ok := e.(*expr.Verdict); ok { + switch action { + case firewall.ActionAccept: + return verdict.Kind == expr.VerdictAccept + case firewall.ActionDrop: + return verdict.Kind == expr.VerdictDrop + } + } + } + return false +} + // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. func check() int { nf := nftables.Conn{} @@ -250,12 +685,12 @@ func createWorkTable() (*nftables.Table, error) { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } - table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) + table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) err = sConn.Flush() return table, err @@ -273,7 +708,7 @@ func deleteWorkTable() { } for _, t := range tables { - if t.Name == tableName { + if t.Name == tableNameNetbird { sConn.DelTable(t) } } diff --git a/client/firewall/test/cases_linux.go b/client/firewall/test/cases_linux.go index 432d113dd..267e93efd 100644 --- a/client/firewall/test/cases_linux.go +++ b/client/firewall/test/cases_linux.go @@ -1,8 +1,10 @@ -//go:build !android - package test -import firewall "github.com/netbirdio/netbird/client/firewall/manager" +import ( + "net/netip" + + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) var ( InsertRuleTestCases = []struct { @@ -13,8 +15,8 @@ var ( Name: "Insert Forwarding IPV4 Rule", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: false, }, }, @@ -22,8 +24,8 @@ var ( Name: "Insert Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, @@ -38,8 +40,8 @@ var ( Name: "Remove Forwarding And Nat IPV4 Rules", InputPair: firewall.RouterPair{ ID: "zxa", - Source: "100.100.100.1/32", - Destination: "100.100.200.0/24", + Source: netip.MustParsePrefix("100.100.100.1/32"), + Destination: netip.MustParsePrefix("100.100.200.0/24"), Masquerade: true, }, }, diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 75792e9c0..0e3ee9799 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "net/netip" "sync" "github.com/google/gopacket" @@ -11,7 +12,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) const layerTypeAll = 0 @@ -22,7 +24,7 @@ var ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { - SetFilter(iface.PacketFilter) error + SetFilter(device.PacketFilter) error Address() iface.WGAddress } @@ -103,26 +105,26 @@ func (m *Manager) IsServerRouteSupported() bool { } } -func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { +func (m *Manager) AddNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.InsertRoutingRules(pair) + return m.nativeFirewall.AddNatRule(pair) } -// RemoveRoutingRules removes a routing firewall rule -func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { +// RemoveNatRule removes a routing firewall rule +func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { if m.nativeFirewall == nil { return errRouteNotSupported } - return m.nativeFirewall.RemoveRoutingRules(pair) + return m.nativeFirewall.RemoveNatRule(pair) } -// AddFiltering rule to the firewall +// AddPeerFiltering rule to the firewall // // If comment argument is empty firewall manager should set // rule ID as comment for the rule -func (m *Manager) AddFiltering( +func (m *Manager) AddPeerFiltering( ip net.IP, proto firewall.Protocol, sPort *firewall.Port, @@ -188,8 +190,22 @@ func (m *Manager) AddFiltering( return []firewall.Rule{&r}, nil } -// DeleteRule from the firewall by rule definition -func (m *Manager) DeleteRule(rule firewall.Rule) error { +func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { + if m.nativeFirewall == nil { + return nil, errRouteNotSupported + } + return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) +} + +func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { + if m.nativeFirewall == nil { + return errRouteNotSupported + } + return m.nativeFirewall.DeleteRouteRule(rule) +} + +// DeletePeerRule from the firewall by rule definition +func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -215,6 +231,11 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error { return nil } +// SetLegacyManagement doesn't need to be implemented for this manager +func (m *Manager) SetLegacyManagement(_ bool) error { + return nil +} + // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } @@ -395,7 +416,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } @@ -403,7 +424,7 @@ func (m *Manager) RemovePacketHook(hookID string) error { for _, r := range arr { if r.id == hookID { rule := r - return m.DeleteRule(&rule) + return m.DeletePeerRule(&rule) } } } diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 514a90539..c188deea4 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -11,15 +11,16 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) type IFaceMock struct { - SetFilterFunc func(iface.PacketFilter) error + SetFilterFunc func(device.PacketFilter) error AddressFunc func() iface.WGAddress } -func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { +func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { if i.SetFilterFunc == nil { return fmt.Errorf("not implemented") } @@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress { func TestManagerCreate(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -49,10 +50,10 @@ func TestManagerCreate(t *testing.T) { } } -func TestManagerAddFiltering(t *testing.T) { +func TestManagerAddPeerFiltering(t *testing.T) { isSetFilterCalled := false ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { + SetFilterFunc: func(device.PacketFilter) error { isSetFilterCalled = true return nil }, @@ -71,7 +72,7 @@ func TestManagerAddFiltering(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -90,7 +91,7 @@ func TestManagerAddFiltering(t *testing.T) { func TestManagerDeleteRule(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -106,7 +107,7 @@ func TestManagerDeleteRule(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -119,14 +120,14 @@ func TestManagerDeleteRule(t *testing.T) { action = fw.ActionDrop comment = "Test rule 2" - rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return } for _, r := range rule { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -140,7 +141,7 @@ func TestManagerDeleteRule(t *testing.T) { } for _, r := range rule2 { - err = m.DeleteRule(r) + err = m.DeletePeerRule(r) if err != nil { t.Errorf("failed to delete rule: %v", err) return @@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) { func TestManagerReset(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -252,7 +253,7 @@ func TestManagerReset(t *testing.T) { action := fw.ActionDrop comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } m, err := Create(ifaceMock) @@ -290,7 +291,7 @@ func TestNotMatchByIP(t *testing.T) { action := fw.ActionAccept comment := "Test rule" - _, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment) + _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment) if err != nil { t.Errorf("failed to add filtering: %v", err) return @@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) { func TestRemovePacketHook(t *testing.T) { // creating mock iface iface := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } // creating manager instance @@ -388,7 +389,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface ifaceMock := &IFaceMock{ - SetFilterFunc: func(iface.PacketFilter) error { return nil }, + SetFilterFunc: func(device.PacketFilter) error { return nil }, } manager, err := Create(ifaceMock) require.NoError(t, err) @@ -406,9 +407,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) { for i := 0; i < testMax; i++ { port := &fw.Port{Values: []int{1000 + i}} if i%2 == 0 { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") } else { - _, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") + _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") } require.NoError(t, err, "failed to add rule") diff --git a/iface/bind/bind.go b/client/iface/bind/bind.go similarity index 100% rename from iface/bind/bind.go rename to client/iface/bind/bind.go diff --git a/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go similarity index 100% rename from iface/bind/udp_mux.go rename to client/iface/bind/udp_mux.go diff --git a/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go similarity index 100% rename from iface/bind/udp_mux_universal.go rename to client/iface/bind/udp_mux_universal.go diff --git a/iface/bind/udp_muxed_conn.go b/client/iface/bind/udp_muxed_conn.go similarity index 100% rename from iface/bind/udp_muxed_conn.go rename to client/iface/bind/udp_muxed_conn.go diff --git a/client/iface/configurer/err.go b/client/iface/configurer/err.go new file mode 100644 index 000000000..a64bba2dd --- /dev/null +++ b/client/iface/configurer/err.go @@ -0,0 +1,5 @@ +package configurer + +import "errors" + +var ErrPeerNotFound = errors.New("peer not found") diff --git a/iface/wg_configurer_kernel_unix.go b/client/iface/configurer/kernel_unix.go similarity index 81% rename from iface/wg_configurer_kernel_unix.go rename to client/iface/configurer/kernel_unix.go index 48ea70b7b..7c1c41669 100644 --- a/iface/wg_configurer_kernel_unix.go +++ b/client/iface/configurer/kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package configurer import ( "fmt" @@ -12,18 +12,17 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type wgKernelConfigurer struct { +type KernelConfigurer struct { deviceName string } -func newWGConfigurer(deviceName string) wgConfigurer { - wgc := &wgKernelConfigurer{ +func NewKernelConfigurer(deviceName string) *KernelConfigurer { + return &KernelConfigurer{ deviceName: deviceName, } - return wgc } -func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error { +func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err return nil } -func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -56,8 +55,9 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA return err } peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - ReplaceAllowedIPs: true, + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: false, + // don't replace allowed ips, wg will handle duplicated peer IP AllowedIPs: []net.IPNet{*ipNet}, PersistentKeepaliveInterval: &keepAlive, Endpoint: endpoint, @@ -74,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA return nil } -func (c *wgKernelConfigurer) removePeer(peerKey string) error { +func (c *KernelConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -95,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error { return nil } -func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -122,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro return nil } -func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { +func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return fmt.Errorf("parse allowed IP: %w", err) @@ -164,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e return nil } -func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { +func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { wg, err := wgctrl.New() if err != nil { return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err) @@ -188,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer return wgtypes.Peer{}, ErrPeerNotFound } -func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { +func (c *KernelConfigurer) configure(config wgtypes.Config) error { wg, err := wgctrl.New() if err != nil { return err @@ -204,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { return wg.ConfigureDevice(c.deviceName, config) } -func (c *wgKernelConfigurer) close() { +func (c *KernelConfigurer) Close() { } -func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) { +func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { peer, err := c.getPeer(c.deviceName, peerKey) if err != nil { return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) diff --git a/iface/name.go b/client/iface/configurer/name.go similarity index 87% rename from iface/name.go rename to client/iface/configurer/name.go index 706cb65ad..e2133d0ea 100644 --- a/iface/name.go +++ b/client/iface/configurer/name.go @@ -1,6 +1,6 @@ //go:build linux || windows || freebsd -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "wt0" diff --git a/iface/name_darwin.go b/client/iface/configurer/name_darwin.go similarity index 86% rename from iface/name_darwin.go rename to client/iface/configurer/name_darwin.go index a4016ce15..034ce388d 100644 --- a/iface/name_darwin.go +++ b/client/iface/configurer/name_darwin.go @@ -1,6 +1,6 @@ //go:build darwin -package iface +package configurer // WgInterfaceDefault is a default interface name of Wiretrustee const WgInterfaceDefault = "utun100" diff --git a/iface/uapi.go b/client/iface/configurer/uapi.go similarity index 96% rename from iface/uapi.go rename to client/iface/configurer/uapi.go index d7ff52e7b..4801841de 100644 --- a/iface/uapi.go +++ b/client/iface/configurer/uapi.go @@ -1,6 +1,6 @@ //go:build !windows -package iface +package configurer import ( "net" diff --git a/iface/uapi_windows.go b/client/iface/configurer/uapi_windows.go similarity index 88% rename from iface/uapi_windows.go rename to client/iface/configurer/uapi_windows.go index e1f466364..46fa90c2e 100644 --- a/iface/uapi_windows.go +++ b/client/iface/configurer/uapi_windows.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "net" diff --git a/iface/wg_configurer_usp.go b/client/iface/configurer/usp.go similarity index 91% rename from iface/wg_configurer_usp.go rename to client/iface/configurer/usp.go index 04a29a60b..21d65ab2a 100644 --- a/iface/wg_configurer_usp.go +++ b/client/iface/configurer/usp.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" @@ -19,15 +19,15 @@ import ( var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") -type wgUSPConfigurer struct { +type WGUSPConfigurer struct { device *device.Device deviceName string uapiListener net.Listener } -func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { - wgCfg := &wgUSPConfigurer{ +func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer { + wgCfg := &WGUSPConfigurer{ device: device, deviceName: deviceName, } @@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { return wgCfg } -func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error { +func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error { log.Debugf("adding Wireguard private key") key, err := wgtypes.ParseKey(privateKey) if err != nil { @@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { +func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { // parse allowed ips _, ipNet, err := net.ParseCIDR(allowedIps) if err != nil { @@ -64,8 +64,9 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv return err } peer := wgtypes.PeerConfig{ - PublicKey: peerKeyParsed, - ReplaceAllowedIPs: true, + PublicKey: peerKeyParsed, + ReplaceAllowedIPs: false, + // don't replace allowed ips, wg will handle duplicated peer IP AllowedIPs: []net.IPNet{*ipNet}, PersistentKeepaliveInterval: &keepAlive, PresharedKey: preSharedKey, @@ -79,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removePeer(peerKey string) error { +func (c *WGUSPConfigurer) RemovePeer(peerKey string) error { peerKeyParsed, err := wgtypes.ParseKey(peerKey) if err != nil { return err @@ -96,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { +func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { _, ipNet, err := net.ParseCIDR(allowedIP) if err != nil { return err @@ -120,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { return c.device.IpcSet(toWgUserspaceString(config)) } -func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { +func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error { ipc, err := c.device.IpcGet() if err != nil { return err @@ -184,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { } // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool -func (t *wgUSPConfigurer) startUAPI() { +func (t *WGUSPConfigurer) startUAPI() { var err error t.uapiListener, err = openUAPI(t.deviceName) if err != nil { @@ -206,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() { }(t.uapiListener) } -func (t *wgUSPConfigurer) close() { +func (t *WGUSPConfigurer) Close() { if t.uapiListener != nil { err := t.uapiListener.Close() if err != nil { @@ -222,7 +223,7 @@ func (t *wgUSPConfigurer) close() { } } -func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) { +func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { ipc, err := t.device.IpcGet() if err != nil { return WGStats{}, fmt.Errorf("ipc get: %w", err) diff --git a/iface/wg_configurer_usp_test.go b/client/iface/configurer/usp_test.go similarity index 99% rename from iface/wg_configurer_usp_test.go rename to client/iface/configurer/usp_test.go index ac0fc6130..775339f24 100644 --- a/iface/wg_configurer_usp_test.go +++ b/client/iface/configurer/usp_test.go @@ -1,4 +1,4 @@ -package iface +package configurer import ( "encoding/hex" diff --git a/client/iface/configurer/wgstats.go b/client/iface/configurer/wgstats.go new file mode 100644 index 000000000..56d0d7310 --- /dev/null +++ b/client/iface/configurer/wgstats.go @@ -0,0 +1,9 @@ +package configurer + +import "time" + +type WGStats struct { + LastHandshake time.Time + TxBytes int64 + RxBytes int64 +} diff --git a/client/iface/device.go b/client/iface/device.go new file mode 100644 index 000000000..0d4e69145 --- /dev/null +++ b/client/iface/device.go @@ -0,0 +1,18 @@ +//go:build !android + +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create() (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/tun_adapter.go b/client/iface/device/adapter.go similarity index 94% rename from iface/tun_adapter.go rename to client/iface/device/adapter.go index adec93ed1..6ebc05390 100644 --- a/iface/tun_adapter.go +++ b/client/iface/device/adapter.go @@ -1,4 +1,4 @@ -package iface +package device // TunAdapter is an interface for create tun device from external service type TunAdapter interface { diff --git a/iface/address.go b/client/iface/device/address.go similarity index 69% rename from iface/address.go rename to client/iface/device/address.go index 5ff4fbc06..15de301da 100644 --- a/iface/address.go +++ b/client/iface/device/address.go @@ -1,18 +1,18 @@ -package iface +package device import ( "fmt" "net" ) -// WGAddress Wireguard parsed address +// WGAddress WireGuard parsed address type WGAddress struct { IP net.IP Network *net.IPNet } -// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address -func parseWGAddress(address string) (WGAddress, error) { +// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address +func ParseWGAddress(address string) (WGAddress, error) { ip, network, err := net.ParseCIDR(address) if err != nil { return WGAddress{}, err diff --git a/iface/tun_args.go b/client/iface/device/args.go similarity index 88% rename from iface/tun_args.go rename to client/iface/device/args.go index 0eac2c4c0..d7b86b335 100644 --- a/iface/tun_args.go +++ b/client/iface/device/args.go @@ -1,4 +1,4 @@ -package iface +package device type MobileIFaceArguments struct { TunAdapter TunAdapter // only for Android diff --git a/iface/tun_android.go b/client/iface/device/device_android.go similarity index 61% rename from iface/tun_android.go rename to client/iface/device/device_android.go index 504993094..29e3f409d 100644 --- a/iface/tun_android.go +++ b/client/iface/device/device_android.go @@ -1,7 +1,6 @@ //go:build android -// +build android -package iface +package device import ( "strings" @@ -12,11 +11,12 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform -type wgTunDevice struct { +// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform +type WGTunDevice struct { address WGAddress port int key string @@ -24,15 +24,15 @@ type wgTunDevice struct { iceBind *bind.ICEBind tunAdapter TunAdapter - name string - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + name string + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { - return wgTunDevice{ +func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { + return &WGTunDevice{ address: address, port: port, key: key, @@ -42,7 +42,7 @@ func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet } } -func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) { +func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { log.Info("create tun interface") routesString := routesToString(routes) @@ -61,24 +61,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string return nil, err } t.name = name - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debugf("attaching to interface %v", name) - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -93,14 +93,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *wgTunDevice) UpdateAddr(addr WGAddress) error { +func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *wgTunDevice) Close() error { +func (t *WGTunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -115,20 +115,20 @@ func (t *wgTunDevice) Close() error { return nil } -func (t *wgTunDevice) Device() *device.Device { +func (t *WGTunDevice) Device() *device.Device { return t.device } -func (t *wgTunDevice) DeviceName() string { +func (t *WGTunDevice) DeviceName() string { return t.name } -func (t *wgTunDevice) WgAddress() WGAddress { +func (t *WGTunDevice) WgAddress() WGAddress { return t.address } -func (t *wgTunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *WGTunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } func routesToString(routes []string) string { diff --git a/iface/tun_darwin.go b/client/iface/device/device_darwin.go similarity index 64% rename from iface/tun_darwin.go rename to client/iface/device/device_darwin.go index 364e5dfad..03e85a7f1 100644 --- a/iface/tun_darwin.go +++ b/client/iface/device/device_darwin.go @@ -1,8 +1,9 @@ //go:build !ios -package iface +package device import ( + "fmt" "os/exec" "github.com/pion/transport/v3" @@ -10,10 +11,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -21,14 +23,14 @@ type tunDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -38,16 +40,16 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int, } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { tunDevice, err := tun.CreateTUN(t.name, t.mtu) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -55,20 +57,20 @@ func (t *tunDevice) Create() (wgConfigurer, error) { err = t.assignAddr() if err != nil { t.device.Close() - return nil, err + return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() - return nil, err + t.configurer.Close() + return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -83,14 +85,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -104,20 +106,20 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) if out, err := cmd.CombinedOutput(); err != nil { log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) diff --git a/iface/device_wrapper.go b/client/iface/device/device_filter.go similarity index 81% rename from iface/device_wrapper.go rename to client/iface/device/device_filter.go index 2fa219395..f87f10429 100644 --- a/iface/device_wrapper.go +++ b/client/iface/device/device_filter.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -28,22 +28,23 @@ type PacketFilter interface { SetNetwork(*net.IPNet) } -// DeviceWrapper to override Read or Write of packets -type DeviceWrapper struct { +// FilteredDevice to override Read or Write of packets +type FilteredDevice struct { tun.Device + filter PacketFilter mutex sync.RWMutex } -// newDeviceWrapper constructor function -func newDeviceWrapper(device tun.Device) *DeviceWrapper { - return &DeviceWrapper{ +// newDeviceFilter constructor function +func newDeviceFilter(device tun.Device) *FilteredDevice { + return &FilteredDevice{ Device: device, } } // Read wraps read method with filtering feature -func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { if n, err = d.Device.Read(bufs, sizes, offset); err != nil { return 0, err } @@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err } // Write wraps write method with filtering feature -func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { +func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { d.mutex.RLock() filter := d.filter d.mutex.RUnlock() @@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { } // SetFilter sets packet filter to device -func (d *DeviceWrapper) SetFilter(filter PacketFilter) { +func (d *FilteredDevice) SetFilter(filter PacketFilter) { d.mutex.Lock() d.filter = filter d.mutex.Unlock() diff --git a/iface/device_wrapper_test.go b/client/iface/device/device_filter_test.go similarity index 95% rename from iface/device_wrapper_test.go rename to client/iface/device/device_filter_test.go index 2d3725ea4..d3278b918 100644 --- a/iface/device_wrapper_test.go +++ b/client/iface/device/device_filter_test.go @@ -1,4 +1,4 @@ -package iface +package device import ( "net" @@ -7,7 +7,8 @@ import ( "github.com/golang/mock/gomock" "github.com/google/gopacket" "github.com/google/gopacket/layers" - mocks "github.com/netbirdio/netbird/iface/mocks" + + mocks "github.com/netbirdio/netbird/client/iface/mocks" ) func TestDeviceWrapperRead(t *testing.T) { @@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) { return 1, nil }) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{{}} sizes := []int{0} @@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) { tun := mocks.NewMockDevice(ctrl) tun.EXPECT().Write(mockBufs, 0).Return(1, nil) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) bufs := [][]byte{buffer.Bytes()} @@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropIncoming(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{buffer.Bytes()} @@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) { filter := mocks.NewMockPacketFilter(ctrl) filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) - wrapped := newDeviceWrapper(tun) + wrapped := newDeviceFilter(tun) wrapped.filter = filter bufs := [][]byte{{}} diff --git a/iface/tun_ios.go b/client/iface/device/device_ios.go similarity index 63% rename from iface/tun_ios.go rename to client/iface/device/device_ios.go index 6d53cc333..226e8a2e0 100644 --- a/iface/tun_ios.go +++ b/client/iface/device/device_ios.go @@ -1,7 +1,7 @@ //go:build ios // +build ios -package iface +package device import ( "os" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +type TunDevice struct { name string address WGAddress port int @@ -23,14 +24,14 @@ type tunDevice struct { iceBind *bind.ICEBind tunFd int - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -40,7 +41,7 @@ func newTunDevice(name string, address WGAddress, port int, key string, transpor } } -func (t *tunDevice) Create() (wgConfigurer, error) { +func (t *TunDevice) Create() (WGConfigurer, error) { log.Infof("create tun interface") dupTunFd, err := unix.Dup(t.tunFd) @@ -62,24 +63,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) { return nil, err } - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) log.Debug("Attaching to interface") - t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) + t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) // without this property mobile devices can discover remote endpoints if the configured one was wrong. // this helps with support for the older NetBird clients that had a hardcoded direct mode // t.device.DisableSomeRoamingForBrokenMobileSemantics() - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() + t.configurer.Close() return nil, err } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -94,17 +95,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) Device() *device.Device { +func (t *TunDevice) Device() *device.Device { return t.device } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -119,15 +120,15 @@ func (t *tunDevice) Close() error { return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) UpdateAddr(addr WGAddress) error { +func (t *TunDevice) UpdateAddr(addr WGAddress) error { // todo implement return nil } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_kernel_unix.go b/client/iface/device/device_kernel_unix.go similarity index 73% rename from iface/tun_kernel_unix.go rename to client/iface/device/device_kernel_unix.go index 019dd786b..f355d2cf7 100644 --- a/iface/tun_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "context" @@ -10,11 +10,12 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/sharedsock" ) -type tunKernelDevice struct { +type TunKernelDevice struct { name string address WGAddress wgPort int @@ -31,11 +32,11 @@ type tunKernelDevice struct { filterFn bind.FilterFn } -func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { +func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { checkUser() ctx, cancel := context.WithCancel(context.Background()) - return &tunKernelDevice{ + return &TunKernelDevice{ ctx: ctx, ctxCancel: cancel, name: name, @@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in } } -func (t *tunKernelDevice) Create() (wgConfigurer, error) { +func (t *TunKernelDevice) Create() (WGConfigurer, error) { link := newWGLink(t.name) if err := link.recreate(); err != nil { @@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) { return nil, fmt.Errorf("set mtu: %w", err) } - configurer := newWGConfigurer(t.name) + configurer := configurer.NewKernelConfigurer(t.name) - if err := configurer.configureInterface(t.key, t.wgPort); err != nil { - return nil, err + if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil { + return nil, fmt.Errorf("error configuring interface: %s", err) } return configurer, nil } -func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.udpMux != nil { return t.udpMux, nil } @@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return t.udpMux, nil } -func (t *tunKernelDevice) UpdateAddr(address WGAddress) error { +func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunKernelDevice) Close() error { +func (t *TunKernelDevice) Close() error { if t.link == nil { return nil } @@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error { return closErr } -func (t *tunKernelDevice) WgAddress() WGAddress { +func (t *TunKernelDevice) WgAddress() WGAddress { return t.address } -func (t *tunKernelDevice) DeviceName() string { +func (t *TunKernelDevice) DeviceName() string { return t.name } -func (t *tunKernelDevice) Wrapper() *DeviceWrapper { +func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } // assignAddr Adds IP address to the tunnel interface -func (t *tunKernelDevice) assignAddr() error { +func (t *TunKernelDevice) assignAddr() error { return t.link.assignAddr(t.address) } diff --git a/iface/tun_netstack.go b/client/iface/device/device_netstack.go similarity index 51% rename from iface/tun_netstack.go rename to client/iface/device/device_netstack.go index df2f75c45..440a1ca19 100644 --- a/iface/tun_netstack.go +++ b/client/iface/device/device_netstack.go @@ -1,7 +1,7 @@ //go:build !android // +build !android -package iface +package device import ( "fmt" @@ -10,11 +10,12 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/device" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/netstack" ) -type tunNetstackDevice struct { +type TunNetstackDevice struct { name string address WGAddress port int @@ -23,15 +24,15 @@ type tunNetstackDevice struct { listenAddress string iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - nsTun *netstack.NetStackTun - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + nsTun *netstack.NetStackTun + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { - return &tunNetstackDevice{ +func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { + return &TunNetstackDevice{ name: name, address: address, port: wgPort, @@ -42,33 +43,33 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string } } -func (t *tunNetstackDevice) Create() (wgConfigurer, error) { +func (t *TunNetstackDevice) Create() (WGConfigurer, error) { log.Info("create netstack tun interface") t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) tunIface, err := t.nsTun.Create() if err != nil { - return nil, err + return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { _ = tunIface.Close() - return nil, err + return nil, fmt.Errorf("error configuring interface: %s", err) } log.Debugf("device has been created: %s", t.name) return t.configurer, nil } -func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -87,13 +88,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunNetstackDevice) UpdateAddr(WGAddress) error { +func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { return nil } -func (t *tunNetstackDevice) Close() error { +func (t *TunNetstackDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -106,14 +107,14 @@ func (t *tunNetstackDevice) Close() error { return nil } -func (t *tunNetstackDevice) WgAddress() WGAddress { +func (t *TunNetstackDevice) WgAddress() WGAddress { return t.address } -func (t *tunNetstackDevice) DeviceName() string { +func (t *TunNetstackDevice) DeviceName() string { return t.name } -func (t *tunNetstackDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } diff --git a/iface/tun_usp_unix.go b/client/iface/device/device_usp_unix.go similarity index 55% rename from iface/tun_usp_unix.go rename to client/iface/device/device_usp_unix.go index 814c9ca89..4175f6556 100644 --- a/iface/tun_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -1,6 +1,6 @@ //go:build (linux && !android) || freebsd -package iface +package device import ( "fmt" @@ -12,10 +12,11 @@ import ( "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunUSPDevice struct { +type USPDevice struct { name string address WGAddress port int @@ -23,39 +24,38 @@ type tunUSPDevice struct { mtu int iceBind *bind.ICEBind - device *device.Device - wrapper *DeviceWrapper - udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + device *device.Device + filteredDevice *FilteredDevice + udpMux *bind.UniversalUDPMuxDefault + configurer WGConfigurer } -func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { +func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { log.Infof("using userspace bind mode") checkUser() - return &tunUSPDevice{ + return &USPDevice{ name: name, address: address, port: port, key: key, mtu: mtu, - iceBind: bind.NewICEBind(transportNet, filterFn), - } + iceBind: bind.NewICEBind(transportNet, filterFn)} } -func (t *tunUSPDevice) Create() (wgConfigurer, error) { +func (t *USPDevice) Create() (WGConfigurer, error) { log.Info("create tun interface") tunIface, err := tun.CreateTUN(t.name, t.mtu) if err != nil { - log.Debugf("failed to create tun unterface (%s, %d): %s", t.name, t.mtu, err) - return nil, err + log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) + return nil, fmt.Errorf("error creating tun device: %s", err) } - t.wrapper = newDeviceWrapper(tunIface) + t.filteredDevice = newDeviceFilter(tunIface) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -63,20 +63,20 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) { err = t.assignAddr() if err != nil { t.device.Close() - return nil, err + return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() - return nil, err + t.configurer.Close() + return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { if t.device == nil { return nil, fmt.Errorf("device is not ready yet") } @@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunUSPDevice) UpdateAddr(address WGAddress) error { +func (t *USPDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunUSPDevice) Close() error { +func (t *USPDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error { return nil } -func (t *tunUSPDevice) WgAddress() WGAddress { +func (t *USPDevice) WgAddress() WGAddress { return t.address } -func (t *tunUSPDevice) DeviceName() string { +func (t *USPDevice) DeviceName() string { return t.name } -func (t *tunUSPDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *USPDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } // assignAddr Adds IP address to the tunnel interface -func (t *tunUSPDevice) assignAddr() error { +func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) return link.assignAddr(t.address) diff --git a/iface/tun_windows.go b/client/iface/device/device_windows.go similarity index 60% rename from iface/tun_windows.go rename to client/iface/device/device_windows.go index 0d658059f..f3e216ccd 100644 --- a/iface/tun_windows.go +++ b/client/iface/device/device_windows.go @@ -1,4 +1,4 @@ -package iface +package device import ( "fmt" @@ -11,10 +11,13 @@ import ( "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" ) -type tunDevice struct { +const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" + +type TunDevice struct { name string address WGAddress port int @@ -24,13 +27,13 @@ type tunDevice struct { device *device.Device nativeTunDevice *tun.NativeTun - wrapper *DeviceWrapper + filteredDevice *FilteredDevice udpMux *bind.UniversalUDPMuxDefault - configurer wgConfigurer + configurer WGConfigurer } -func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { - return &tunDevice{ +func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { + return &TunDevice{ name: name, address: address, port: port, @@ -40,18 +43,31 @@ func newTunDevice(name string, address WGAddress, port int, key string, mtu int, } } -func (t *tunDevice) Create() (wgConfigurer, error) { - log.Info("create tun interface") - tunDevice, err := tun.CreateTUN(t.name, t.mtu) +func getGUID() (windows.GUID, error) { + guidString := defaultWindowsGUIDSTring + if CustomWindowsGUIDString != "" { + guidString = CustomWindowsGUIDString + } + return windows.GUIDFromString(guidString) +} + +func (t *TunDevice) Create() (WGConfigurer, error) { + guid, err := getGUID() if err != nil { + log.Errorf("failed to get GUID: %s", err) return nil, err } + log.Info("create tun interface") + tunDevice, err := tun.CreateTUNWithRequestedGUID(t.name, &guid, t.mtu) + if err != nil { + return nil, fmt.Errorf("error creating tun device: %s", err) + } t.nativeTunDevice = tunDevice.(*tun.NativeTun) - t.wrapper = newDeviceWrapper(tunDevice) + t.filteredDevice = newDeviceFilter(tunDevice) // We need to create a wireguard-go device and listen to configuration requests t.device = device.NewDevice( - t.wrapper, + t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[netbird] "), ) @@ -74,20 +90,20 @@ func (t *tunDevice) Create() (wgConfigurer, error) { err = t.assignAddr() if err != nil { t.device.Close() - return nil, err + return nil, fmt.Errorf("error assigning ip: %s", err) } - t.configurer = newWGUSPConfigurer(t.device, t.name) - err = t.configurer.configureInterface(t.key, t.port) + t.configurer = configurer.NewUSPConfigurer(t.device, t.name) + err = t.configurer.ConfigureInterface(t.key, t.port) if err != nil { t.device.Close() - t.configurer.close() - return nil, err + t.configurer.Close() + return nil, fmt.Errorf("error configuring interface: %s", err) } return t.configurer, nil } -func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { +func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { err := t.device.Up() if err != nil { return nil, err @@ -102,14 +118,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { return udpMux, nil } -func (t *tunDevice) UpdateAddr(address WGAddress) error { +func (t *TunDevice) UpdateAddr(address WGAddress) error { t.address = address return t.assignAddr() } -func (t *tunDevice) Close() error { +func (t *TunDevice) Close() error { if t.configurer != nil { - t.configurer.close() + t.configurer.Close() } if t.device != nil { @@ -123,19 +139,19 @@ func (t *tunDevice) Close() error { } return nil } -func (t *tunDevice) WgAddress() WGAddress { +func (t *TunDevice) WgAddress() WGAddress { return t.address } -func (t *tunDevice) DeviceName() string { +func (t *TunDevice) DeviceName() string { return t.name } -func (t *tunDevice) Wrapper() *DeviceWrapper { - return t.wrapper +func (t *TunDevice) FilteredDevice() *FilteredDevice { + return t.filteredDevice } -func (t *tunDevice) getInterfaceGUIDString() (string, error) { +func (t *TunDevice) GetInterfaceGUIDString() (string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") } @@ -149,7 +165,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) { } // assignAddr Adds IP address to the tunnel interface and network route based on the range provided -func (t *tunDevice) assignAddr() error { +func (t *TunDevice) assignAddr() error { luid := winipcfg.LUID(t.nativeTunDevice.LUID()) log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) diff --git a/client/iface/device/interface.go b/client/iface/device/interface.go new file mode 100644 index 000000000..0196b0085 --- /dev/null +++ b/client/iface/device/interface.go @@ -0,0 +1,20 @@ +package device + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type WGConfigurer interface { + ConfigureInterface(privateKey string, port int) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() + GetStats(peerKey string) (configurer.WGStats, error) +} diff --git a/iface/module.go b/client/iface/device/kernel_module.go similarity index 92% rename from iface/module.go rename to client/iface/device/kernel_module.go index ca70cf3c7..1bdd6f7c6 100644 --- a/iface/module.go +++ b/client/iface/device/kernel_module.go @@ -1,6 +1,6 @@ //go:build (!linux && !freebsd) || android -package iface +package device // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) func WireGuardModuleIsLoaded() bool { diff --git a/iface/module_freebsd.go b/client/iface/device/kernel_module_freebsd.go similarity index 84% rename from iface/module_freebsd.go rename to client/iface/device/kernel_module_freebsd.go index 00ad882c2..dd6c8b408 100644 --- a/iface/module_freebsd.go +++ b/client/iface/device/kernel_module_freebsd.go @@ -1,4 +1,4 @@ -package iface +package device // WireGuardModuleIsLoaded check if kernel support wireguard func WireGuardModuleIsLoaded() bool { @@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool { return false } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { // Assume tun supported by freebsd kernel by default // TODO: implement check for module loaded in kernel or build-it return true diff --git a/iface/module_linux.go b/client/iface/device/kernel_module_linux.go similarity index 98% rename from iface/module_linux.go rename to client/iface/device/kernel_module_linux.go index 11c0482d5..0d195779d 100644 --- a/iface/module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -1,7 +1,7 @@ //go:build linux && !android // Package iface provides wireguard network interface creation and management -package iface +package device import ( "bufio" @@ -66,8 +66,8 @@ func getModuleRoot() string { return filepath.Join(moduleLibDir, string(uname.Release[:i])) } -// tunModuleIsLoaded check if tun module exist, if is not attempt to load it -func tunModuleIsLoaded() bool { +// ModuleTunIsLoaded check if tun module exist, if is not attempt to load it +func ModuleTunIsLoaded() bool { _, err := os.Stat("/dev/net/tun") if err == nil { return true diff --git a/iface/module_linux_test.go b/client/iface/device/kernel_module_linux_test.go similarity index 98% rename from iface/module_linux_test.go rename to client/iface/device/kernel_module_linux_test.go index 97e9b1f78..de9656e47 100644 --- a/iface/module_linux_test.go +++ b/client/iface/device/kernel_module_linux_test.go @@ -1,4 +1,6 @@ -package iface +//go:build linux && !android + +package device import ( "bufio" @@ -132,7 +134,7 @@ func resetGlobals() { } func createFiles(t *testing.T) (string, []module) { - t.Helper() + t.Helper() writeFile := func(path, text string) { if err := os.WriteFile(path, []byte(text), 0644); err != nil { t.Fatal(err) @@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) { } func getRandomLoadedModule(t *testing.T) (string, error) { - t.Helper() + t.Helper() f, err := os.Open("/proc/modules") if err != nil { return "", err diff --git a/iface/tun_link_freebsd.go b/client/iface/device/wg_link_freebsd.go similarity index 95% rename from iface/tun_link_freebsd.go rename to client/iface/device/wg_link_freebsd.go index be7921fdb..104010f47 100644 --- a/iface/tun_link_freebsd.go +++ b/client/iface/device/wg_link_freebsd.go @@ -1,10 +1,11 @@ -package iface +package device import ( "fmt" - "github.com/netbirdio/netbird/iface/freebsd" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/freebsd" ) type wgLink struct { diff --git a/iface/tun_link_linux.go b/client/iface/device/wg_link_linux.go similarity index 99% rename from iface/tun_link_linux.go rename to client/iface/device/wg_link_linux.go index 3ce644e84..a15cffe48 100644 --- a/iface/tun_link_linux.go +++ b/client/iface/device/wg_link_linux.go @@ -1,6 +1,6 @@ //go:build linux && !android -package iface +package device import ( "fmt" diff --git a/iface/wg_log.go b/client/iface/device/wg_log.go similarity index 93% rename from iface/wg_log.go rename to client/iface/device/wg_log.go index b44f6fc0b..db2f3111f 100644 --- a/iface/wg_log.go +++ b/client/iface/device/wg_log.go @@ -1,4 +1,4 @@ -package iface +package device import ( "os" diff --git a/client/iface/device/windows_guid.go b/client/iface/device/windows_guid.go new file mode 100644 index 000000000..1c7d40d13 --- /dev/null +++ b/client/iface/device/windows_guid.go @@ -0,0 +1,4 @@ +package device + +// CustomWindowsGUIDString is a custom GUID string for the interface +var CustomWindowsGUIDString string diff --git a/client/iface/device_android.go b/client/iface/device_android.go new file mode 100644 index 000000000..3d15080ff --- /dev/null +++ b/client/iface/device_android.go @@ -0,0 +1,16 @@ +package iface + +import ( + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" +) + +type WGTunDevice interface { + Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(address WGAddress) error + WgAddress() WGAddress + DeviceName() string + Close() error + FilteredDevice() *device.FilteredDevice +} diff --git a/iface/freebsd/errors.go b/client/iface/freebsd/errors.go similarity index 100% rename from iface/freebsd/errors.go rename to client/iface/freebsd/errors.go diff --git a/iface/freebsd/iface.go b/client/iface/freebsd/iface.go similarity index 100% rename from iface/freebsd/iface.go rename to client/iface/freebsd/iface.go diff --git a/iface/freebsd/iface_internal_test.go b/client/iface/freebsd/iface_internal_test.go similarity index 100% rename from iface/freebsd/iface_internal_test.go rename to client/iface/freebsd/iface_internal_test.go diff --git a/iface/freebsd/link.go b/client/iface/freebsd/link.go similarity index 100% rename from iface/freebsd/link.go rename to client/iface/freebsd/link.go diff --git a/iface/iface.go b/client/iface/iface.go similarity index 59% rename from iface/iface.go rename to client/iface/iface.go index 928077a3d..accf5ce0a 100644 --- a/iface/iface.go +++ b/client/iface/iface.go @@ -9,28 +9,27 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) const ( - DefaultMTU = 1280 - DefaultWgPort = 51820 + DefaultMTU = 1280 + DefaultWgPort = 51820 + WgInterfaceDefault = configurer.WgInterfaceDefault ) -// WGIface represents a interface instance +type WGAddress = device.WGAddress + +// WGIface represents an interface instance type WGIface struct { - tun wgTunDevice + tun WGTunDevice userspaceBind bool mu sync.Mutex - configurer wgConfigurer - filter PacketFilter -} - -type WGStats struct { - LastHandshake time.Time - TxBytes int64 - RxBytes int64 + configurer device.WGConfigurer + filter device.PacketFilter } // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind @@ -44,7 +43,7 @@ func (w *WGIface) Name() string { } // Address returns the interface address -func (w *WGIface) Address() WGAddress { +func (w *WGIface) Address() device.WGAddress { return w.tun.WgAddress() } @@ -75,7 +74,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error { w.mu.Lock() defer w.mu.Unlock() - addr, err := parseWGAddress(newAddr) + addr, err := device.ParseWGAddress(newAddr) if err != nil { return err } @@ -90,7 +89,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D defer w.mu.Unlock() log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) - return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) } // RemovePeer removes a Wireguard Peer from the interface iface @@ -99,7 +98,7 @@ func (w *WGIface) RemovePeer(peerKey string) error { defer w.mu.Unlock() log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) - return w.configurer.removePeer(peerKey) + return w.configurer.RemovePeer(peerKey) } // AddAllowedIP adds a prefix to the allowed IPs list of peer @@ -108,7 +107,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.addAllowedIP(peerKey, allowedIP) + return w.configurer.AddAllowedIP(peerKey, allowedIP) } // RemoveAllowedIP removes a prefix from the allowed IPs list of peer @@ -117,34 +116,50 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { defer w.mu.Unlock() log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) - return w.configurer.removeAllowedIP(peerKey, allowedIP) + return w.configurer.RemoveAllowedIP(peerKey, allowedIP) } // Close closes the tunnel interface func (w *WGIface) Close() error { w.mu.Lock() defer w.mu.Unlock() - return w.tun.Close() + + err := w.tun.Close() + if err != nil { + return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err) + } + + err = w.waitUntilRemoved() + if err != nil { + log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) + err = w.Destroy() + if err != nil { + return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err) + } + log.Infof("interface %s successfully removed", w.Name()) + } + + return nil } // SetFilter sets packet filters for the userspace implementation -func (w *WGIface) SetFilter(filter PacketFilter) error { +func (w *WGIface) SetFilter(filter device.PacketFilter) error { w.mu.Lock() defer w.mu.Unlock() - if w.tun.Wrapper() == nil { + if w.tun.FilteredDevice() == nil { return fmt.Errorf("userspace packet filtering not handled on this device") } w.filter = filter w.filter.SetNetwork(w.tun.WgAddress().Network) - w.tun.Wrapper().SetFilter(filter) + w.tun.FilteredDevice().SetFilter(filter) return nil } // GetFilter returns packet filter used by interface if it uses userspace device implementation -func (w *WGIface) GetFilter() PacketFilter { +func (w *WGIface) GetFilter() device.PacketFilter { w.mu.Lock() defer w.mu.Unlock() @@ -152,14 +167,41 @@ func (w *WGIface) GetFilter() PacketFilter { } // GetDevice to interact with raw device (with filtering) -func (w *WGIface) GetDevice() *DeviceWrapper { +func (w *WGIface) GetDevice() *device.FilteredDevice { w.mu.Lock() defer w.mu.Unlock() - return w.tun.Wrapper() + return w.tun.FilteredDevice() } // GetStats returns the last handshake time, rx and tx bytes for the given peer -func (w *WGIface) GetStats(peerKey string) (WGStats, error) { - return w.configurer.getStats(peerKey) +func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { + return w.configurer.GetStats(peerKey) +} + +func (w *WGIface) waitUntilRemoved() error { + maxWaitTime := 5 * time.Second + timeout := time.NewTimer(maxWaitTime) + defer timeout.Stop() + + for { + iface, err := net.InterfaceByName(w.Name()) + if err != nil { + if _, ok := err.(*net.OpError); ok { + log.Infof("interface %s has been removed", w.Name()) + return nil + } + log.Debugf("failed to get interface by name %s: %v", w.Name(), err) + } else if iface == nil { + log.Infof("interface %s has been removed", w.Name()) + return nil + } + + select { + case <-timeout.C: + return fmt.Errorf("timeout when waiting for interface %s to be removed", w.Name()) + default: + time.Sleep(100 * time.Millisecond) + } + } } diff --git a/iface/iface_android.go b/client/iface/iface_android.go similarity index 67% rename from iface/iface_android.go rename to client/iface/iface_android.go index 99f6885a5..5ed476e70 100644 --- a/iface/iface_android.go +++ b/client/iface/iface_android.go @@ -5,18 +5,19 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), + tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/iface/iface_create.go b/client/iface/iface_create.go similarity index 89% rename from iface/iface_create.go rename to client/iface/iface_create.go index cfc555f2e..f389019ed 100644 --- a/iface/iface_create.go +++ b/client/iface/iface_create.go @@ -1,4 +1,4 @@ -//go:build !android +//go:build (!android && !darwin) || ios package iface diff --git a/client/iface/iface_darwin.go b/client/iface/iface_darwin.go new file mode 100644 index 000000000..b46ea0f80 --- /dev/null +++ b/client/iface/iface_darwin.go @@ -0,0 +1,67 @@ +//go:build !ios + +package iface + +import ( + "fmt" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/pion/transport/v3" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{ + userspaceBind: true, + } + + if netstack.IsEnabled() { + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + return wgIFace, nil + } + + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + + return wgIFace, nil +} + +// CreateOnAndroid this function make sense on mobile only +func (w *WGIface) CreateOnAndroid([]string, string, []string) error { + return fmt.Errorf("this function has not implemented on this platform") +} + +// Create creates a new Wireguard interface, sets a given IP and brings it up. +// Will reuse an existing one. +// this function is different on Android +func (w *WGIface) Create() error { + w.mu.Lock() + defer w.mu.Unlock() + + backOff := &backoff.ExponentialBackOff{ + InitialInterval: 20 * time.Millisecond, + MaxElapsedTime: 500 * time.Millisecond, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + } + + operation := func() error { + cfgr, err := w.tun.Create() + if err != nil { + return err + } + w.configurer = cfgr + return nil + } + + return backoff.Retry(operation, backOff) +} diff --git a/client/iface/iface_destroy_bsd.go b/client/iface/iface_destroy_bsd.go new file mode 100644 index 000000000..c16010a1c --- /dev/null +++ b/client/iface/iface_destroy_bsd.go @@ -0,0 +1,17 @@ +//go:build darwin || dragonfly || freebsd || netbsd || openbsd + +package iface + +import ( + "fmt" + "os/exec" +) + +func (w *WGIface) Destroy() error { + out, err := exec.Command("ifconfig", w.Name(), "destroy").CombinedOutput() + if err != nil { + return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out) + } + + return nil +} diff --git a/client/iface/iface_destroy_linux.go b/client/iface/iface_destroy_linux.go new file mode 100644 index 000000000..e9d54bed1 --- /dev/null +++ b/client/iface/iface_destroy_linux.go @@ -0,0 +1,22 @@ +//go:build linux && !android + +package iface + +import ( + "fmt" + + "github.com/vishvananda/netlink" +) + +func (w *WGIface) Destroy() error { + link, err := netlink.LinkByName(w.Name()) + if err != nil { + return fmt.Errorf("failed to get link by name %s: %w", w.Name(), err) + } + + if err := netlink.LinkDel(link); err != nil { + return fmt.Errorf("failed to delete link %s: %w", w.Name(), err) + } + + return nil +} diff --git a/client/iface/iface_destroy_mobile.go b/client/iface/iface_destroy_mobile.go new file mode 100644 index 000000000..89f87a598 --- /dev/null +++ b/client/iface/iface_destroy_mobile.go @@ -0,0 +1,9 @@ +//go:build android || (ios && !darwin) + +package iface + +import "errors" + +func (w *WGIface) Destroy() error { + return errors.New("not supported on mobile") +} diff --git a/client/iface/iface_destroy_windows.go b/client/iface/iface_destroy_windows.go new file mode 100644 index 000000000..0bfa4e211 --- /dev/null +++ b/client/iface/iface_destroy_windows.go @@ -0,0 +1,32 @@ +//go:build windows + +package iface + +import ( + "fmt" + "os/exec" + + log "github.com/sirupsen/logrus" +) + +func (w *WGIface) Destroy() error { + netshCmd := GetSystem32Command("netsh") + out, err := exec.Command(netshCmd, "interface", "set", "interface", w.Name(), "admin=disable").CombinedOutput() + if err != nil { + return fmt.Errorf("failed to remove interface %s: %w - %s", w.Name(), err, out) + } + return nil +} + +// GetSystem32Command checks if a command can be found in the system path and returns it. In case it can't find it +// in the path it will return the full path of a command assuming C:\windows\system32 as the base path. +func GetSystem32Command(command string) string { + _, err := exec.LookPath(command) + if err == nil { + return command + } + + log.Tracef("Command %s not found in PATH, using C:\\windows\\system32\\%s.exe path", command, command) + + return "C:\\windows\\system32\\" + command + ".exe" +} diff --git a/iface/iface_ios.go b/client/iface/iface_ios.go similarity index 59% rename from iface/iface_ios.go rename to client/iface/iface_ios.go index 6babe5964..fc0214748 100644 --- a/iface/iface_ios.go +++ b/client/iface/iface_ios.go @@ -7,17 +7,18 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } wgIFace := &WGIface{ - tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), + tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn), userspaceBind: true, } return wgIFace, nil diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go new file mode 100644 index 000000000..703da9ce0 --- /dev/null +++ b/client/iface/iface_moc.go @@ -0,0 +1,105 @@ +package iface + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) + +type MockWGIface struct { + CreateFunc func() error + CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error + IsUserspaceBindFunc func() bool + NameFunc func() string + AddressFunc func() device.WGAddress + ToInterfaceFunc func() *net.Interface + UpFunc func() (*bind.UniversalUDPMuxDefault, error) + UpdateAddrFunc func(newAddr string) error + UpdatePeerFunc func(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeerFunc func(peerKey string) error + AddAllowedIPFunc func(peerKey string, allowedIP string) error + RemoveAllowedIPFunc func(peerKey string, allowedIP string) error + CloseFunc func() error + SetFilterFunc func(filter device.PacketFilter) error + GetFilterFunc func() device.PacketFilter + GetDeviceFunc func() *device.FilteredDevice + GetStatsFunc func(peerKey string) (configurer.WGStats, error) + GetInterfaceGUIDStringFunc func() (string, error) +} + +func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { + return m.GetInterfaceGUIDStringFunc() +} + +func (m *MockWGIface) Create() error { + return m.CreateFunc() +} + +func (m *MockWGIface) CreateOnAndroid(routeRange []string, ip string, domains []string) error { + return m.CreateOnAndroidFunc(routeRange, ip, domains) +} + +func (m *MockWGIface) IsUserspaceBind() bool { + return m.IsUserspaceBindFunc() +} + +func (m *MockWGIface) Name() string { + return m.NameFunc() +} + +func (m *MockWGIface) Address() device.WGAddress { + return m.AddressFunc() +} + +func (m *MockWGIface) ToInterface() *net.Interface { + return m.ToInterfaceFunc() +} + +func (m *MockWGIface) Up() (*bind.UniversalUDPMuxDefault, error) { + return m.UpFunc() +} + +func (m *MockWGIface) UpdateAddr(newAddr string) error { + return m.UpdateAddrFunc(newAddr) +} + +func (m *MockWGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + return m.UpdatePeerFunc(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) +} + +func (m *MockWGIface) RemovePeer(peerKey string) error { + return m.RemovePeerFunc(peerKey) +} + +func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error { + return m.AddAllowedIPFunc(peerKey, allowedIP) +} + +func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { + return m.RemoveAllowedIPFunc(peerKey, allowedIP) +} + +func (m *MockWGIface) Close() error { + return m.CloseFunc() +} + +func (m *MockWGIface) SetFilter(filter device.PacketFilter) error { + return m.SetFilterFunc(filter) +} + +func (m *MockWGIface) GetFilter() device.PacketFilter { + return m.GetFilterFunc() +} + +func (m *MockWGIface) GetDevice() *device.FilteredDevice { + return m.GetDeviceFunc() +} + +func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { + return m.GetStatsFunc(peerKey) +} diff --git a/iface/iface_test.go b/client/iface/iface_test.go similarity index 84% rename from iface/iface_test.go rename to client/iface/iface_test.go index 43c44b770..87a68addb 100644 --- a/iface/iface_test.go +++ b/client/iface/iface_test.go @@ -4,14 +4,18 @@ import ( "fmt" "net" "net/netip" + "strings" "testing" "time" + "github.com/google/uuid" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/device" ) // keep darwin compatibility @@ -174,6 +178,72 @@ func Test_Close(t *testing.T) { } } +func TestRecreation(t *testing.T) { + for i := 0; i < 100; i++ { + t.Run(fmt.Sprintf("down-%d", i), func(t *testing.T) { + ifaceName := fmt.Sprintf("utun%d", WgIntNumber+2) + wgIP := "10.99.99.2/32" + wgPort := 33100 + newNet, err := stdnet.NewNet() + if err != nil { + t.Fatal(err) + } + + iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) + if err != nil { + t.Fatal(err) + } + + for { + _, err = net.InterfaceByName(ifaceName) + if err != nil { + t.Logf("interface %s not found: err: %s", ifaceName, err) + break + } + t.Logf("interface %s found", ifaceName) + } + + err = iface.Create() + if err != nil { + t.Fatal(err) + } + wg, err := wgctrl.New() + if err != nil { + t.Fatal(err) + } + defer func() { + err = wg.Close() + if err != nil { + t.Error(err) + } + }() + + _, err = iface.Up() + if err != nil { + t.Fatal(err) + } + + for { + _, err = net.InterfaceByName(ifaceName) + if err == nil { + t.Logf("interface %s found", ifaceName) + + break + } + t.Logf("interface %s not found: err: %s", ifaceName, err) + + } + + start := time.Now() + err = iface.Close() + t.Logf("down time: %s", time.Since(start)) + if err != nil { + t.Fatal(err) + } + }) + } +} + func Test_ConfigureInterface(t *testing.T) { ifaceName := fmt.Sprintf("utun%d", WgIntNumber+3) wgIP := "10.99.99.5/30" @@ -345,6 +415,9 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } + guid := fmt.Sprintf("{%s}", uuid.New().String()) + device.CustomWindowsGUIDString = strings.ToLower(guid) + iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) if err != nil { t.Fatal(err) @@ -364,6 +437,9 @@ func Test_ConnectPeers(t *testing.T) { t.Fatal(err) } + guid = fmt.Sprintf("{%s}", uuid.New().String()) + device.CustomWindowsGUIDString = strings.ToLower(guid) + newNet, err = stdnet.NewNet() if err != nil { t.Fatal(err) diff --git a/iface/iface_unix.go b/client/iface/iface_unix.go similarity index 53% rename from iface/iface_unix.go rename to client/iface/iface_unix.go index 9608df1ad..09dbb2c1f 100644 --- a/iface/iface_unix.go +++ b/client/iface/iface_unix.go @@ -8,13 +8,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -23,21 +24,21 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, // move the kernel/usp/netstack preference evaluation to upper layer if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) wgIFace.userspaceBind = true return wgIFace, nil } - if WireGuardModuleIsLoaded() { - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) + if device.WireGuardModuleIsLoaded() { + wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.userspaceBind = false return wgIFace, nil } - if !tunModuleIsLoaded() { + if !device.ModuleTunIsLoaded() { return nil, fmt.Errorf("couldn't check or load tun module") } - wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) + wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil) wgIFace.userspaceBind = true return wgIFace, nil } diff --git a/iface/iface_windows.go b/client/iface/iface_windows.go similarity index 52% rename from iface/iface_windows.go rename to client/iface/iface_windows.go index c5edd27a9..6845ef3dd 100644 --- a/iface/iface_windows.go +++ b/client/iface/iface_windows.go @@ -5,13 +5,14 @@ import ( "github.com/pion/transport/v3" - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" ) // NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) +func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { + wgAddress, err := device.ParseWGAddress(address) if err != nil { return nil, err } @@ -21,11 +22,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, } if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) + wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) return wgIFace, nil } - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) + wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) return wgIFace, nil } @@ -36,5 +37,5 @@ func (w *WGIface) CreateOnAndroid([]string, string, []string) error { // GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only func (w *WGIface) GetInterfaceGUIDString() (string, error) { - return w.tun.(*tunDevice).getInterfaceGUIDString() + return w.tun.(*device.TunDevice).GetInterfaceGUIDString() } diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go new file mode 100644 index 000000000..cb6d7ccd9 --- /dev/null +++ b/client/iface/iwginterface.go @@ -0,0 +1,34 @@ +//go:build !windows + +package iface + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) + +type IWGIface interface { + Create() error + CreateOnAndroid(routeRange []string, ip string, domains []string) error + IsUserspaceBind() bool + Name() string + Address() device.WGAddress + ToInterface() *net.Interface + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(newAddr string) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() error + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) +} diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go new file mode 100644 index 000000000..6baeb66ae --- /dev/null +++ b/client/iface/iwginterface_windows.go @@ -0,0 +1,33 @@ +package iface + +import ( + "net" + "time" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) + +type IWGIface interface { + Create() error + CreateOnAndroid(routeRange []string, ip string, domains []string) error + IsUserspaceBind() bool + Name() string + Address() device.WGAddress + ToInterface() *net.Interface + Up() (*bind.UniversalUDPMuxDefault, error) + UpdateAddr(newAddr string) error + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(peerKey string) error + AddAllowedIP(peerKey string, allowedIP string) error + RemoveAllowedIP(peerKey string, allowedIP string) error + Close() error + SetFilter(filter device.PacketFilter) error + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) + GetInterfaceGUIDString() (string, error) +} diff --git a/iface/mocks/README.md b/client/iface/mocks/README.md similarity index 100% rename from iface/mocks/README.md rename to client/iface/mocks/README.md diff --git a/iface/mocks/filter.go b/client/iface/mocks/filter.go similarity index 97% rename from iface/mocks/filter.go rename to client/iface/mocks/filter.go index 2d80d69f1..6348e0e77 100644 --- a/iface/mocks/filter.go +++ b/client/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/iface/mocks/filter.go b/client/iface/mocks/iface/mocks/filter.go similarity index 97% rename from iface/mocks/iface/mocks/filter.go rename to client/iface/mocks/iface/mocks/filter.go index 059a2b9a0..17e123abb 100644 --- a/iface/mocks/iface/mocks/filter.go +++ b/client/iface/mocks/iface/mocks/filter.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/netbirdio/netbird/iface (interfaces: PacketFilter) +// Source: github.com/netbirdio/netbird/client/iface (interfaces: PacketFilter) // Package mocks is a generated GoMock package. package mocks diff --git a/iface/mocks/tun.go b/client/iface/mocks/tun.go similarity index 100% rename from iface/mocks/tun.go rename to client/iface/mocks/tun.go diff --git a/iface/netstack/dialer.go b/client/iface/netstack/dialer.go similarity index 100% rename from iface/netstack/dialer.go rename to client/iface/netstack/dialer.go diff --git a/iface/netstack/env.go b/client/iface/netstack/env.go similarity index 100% rename from iface/netstack/env.go rename to client/iface/netstack/env.go diff --git a/iface/netstack/proxy.go b/client/iface/netstack/proxy.go similarity index 100% rename from iface/netstack/proxy.go rename to client/iface/netstack/proxy.go diff --git a/iface/netstack/tun.go b/client/iface/netstack/tun.go similarity index 100% rename from iface/netstack/tun.go rename to client/iface/netstack/tun.go diff --git a/client/internal/acl/id/id.go b/client/internal/acl/id/id.go new file mode 100644 index 000000000..e27fce439 --- /dev/null +++ b/client/internal/acl/id/id.go @@ -0,0 +1,25 @@ +package id + +import ( + "fmt" + "net/netip" + + "github.com/netbirdio/netbird/client/firewall/manager" +) + +type RuleID string + +func (r RuleID) GetRuleID() string { + return string(r) +} + +func GenerateRouteRuleKey( + sources []netip.Prefix, + destination netip.Prefix, + proto manager.Protocol, + sPort *manager.Port, + dPort *manager.Port, + action manager.Action, +) RuleID { + return RuleID(fmt.Sprintf("%s-%s-%s-%s-%s-%d", sources, destination, proto, sPort, dPort, action)) +} diff --git a/client/internal/acl/manager.go b/client/internal/acl/manager.go index fd2c2c875..ce2a12af1 100644 --- a/client/internal/acl/manager.go +++ b/client/internal/acl/manager.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "net" + "net/netip" "strconv" "sync" "time" @@ -12,6 +13,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/ssh" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -23,16 +25,18 @@ type Manager interface { // DefaultManager uses firewall manager to handle type DefaultManager struct { - firewall firewall.Manager - ipsetCounter int - rulesPairs map[string][]firewall.Rule - mutex sync.Mutex + firewall firewall.Manager + ipsetCounter int + peerRulesPairs map[id.RuleID][]firewall.Rule + routeRules map[id.RuleID]struct{} + mutex sync.Mutex } func NewDefaultManager(fm firewall.Manager) *DefaultManager { return &DefaultManager{ - firewall: fm, - rulesPairs: make(map[string][]firewall.Rule), + firewall: fm, + peerRulesPairs: make(map[id.RuleID][]firewall.Rule), + routeRules: make(map[id.RuleID]struct{}), } } @@ -46,7 +50,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { start := time.Now() defer func() { total := 0 - for _, pairs := range d.rulesPairs { + for _, pairs := range d.peerRulesPairs { total += len(pairs) } log.Infof( @@ -59,21 +63,34 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { return } - defer func() { - if err := d.firewall.Flush(); err != nil { - log.Error("failed to flush firewall rules: ", err) - } - }() + d.applyPeerACLs(networkMap) + // If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag, + // then the mgmt server is older than the client, and we need to allow all traffic for routes + isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty + if err := d.firewall.SetLegacyManagement(isLegacy); err != nil { + log.Errorf("failed to set legacy management flag: %v", err) + } + + if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { + log.Errorf("Failed to apply route ACLs: %v", err) + } + + if err := d.firewall.Flush(); err != nil { + log.Error("failed to flush firewall rules: ", err) + } +} + +func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { rules, squashedProtocols := d.squashAcceptRules(networkMap) - enableSSH := (networkMap.PeerConfig != nil && + enableSSH := networkMap.PeerConfig != nil && networkMap.PeerConfig.SshConfig != nil && - networkMap.PeerConfig.SshConfig.SshEnabled) - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + networkMap.PeerConfig.SshConfig.SshEnabled + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { enableSSH = enableSSH && !ok } - if _, ok := squashedProtocols[mgmProto.FirewallRule_TCP]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_TCP]; ok { enableSSH = enableSSH && !ok } @@ -83,9 +100,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { if enableSSH { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: strconv.Itoa(ssh.DefaultSSHPort), }) } @@ -97,20 +114,20 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { rules = append(rules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, ) } - newRulePairs := make(map[string][]firewall.Rule) + newRulePairs := make(map[id.RuleID][]firewall.Rule) ipsetByRuleSelectors := make(map[string]string) for _, r := range rules { @@ -130,29 +147,97 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { break } if len(rules) > 0 { - d.rulesPairs[pairID] = rulePair + d.peerRulesPairs[pairID] = rulePair newRulePairs[pairID] = rulePair } } - for pairID, rules := range d.rulesPairs { + for pairID, rules := range d.peerRulesPairs { if _, ok := newRulePairs[pairID]; !ok { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { - log.Errorf("failed to delete firewall rule: %v", err) + if err := d.firewall.DeletePeerRule(rule); err != nil { + log.Errorf("failed to delete peer firewall rule: %v", err) continue } } - delete(d.rulesPairs, pairID) + delete(d.peerRulesPairs, pairID) } } - d.rulesPairs = newRulePairs + d.peerRulesPairs = newRulePairs +} + +func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { + var newRouteRules = make(map[id.RuleID]struct{}) + for _, rule := range rules { + id, err := d.applyRouteACL(rule) + if err != nil { + return fmt.Errorf("apply route ACL: %w", err) + } + newRouteRules[id] = struct{}{} + } + + for id := range d.routeRules { + if _, ok := newRouteRules[id]; !ok { + if err := d.firewall.DeleteRouteRule(id); err != nil { + log.Errorf("failed to delete route firewall rule: %v", err) + continue + } + delete(d.routeRules, id) + } + } + d.routeRules = newRouteRules + return nil +} + +func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { + if len(rule.SourceRanges) == 0 { + return "", fmt.Errorf("source ranges is empty") + } + + var sources []netip.Prefix + for _, sourceRange := range rule.SourceRanges { + source, err := netip.ParsePrefix(sourceRange) + if err != nil { + return "", fmt.Errorf("parse source range: %w", err) + } + sources = append(sources, source) + } + + var destination netip.Prefix + if rule.IsDynamic { + destination = getDefault(sources[0]) + } else { + var err error + destination, err = netip.ParsePrefix(rule.Destination) + if err != nil { + return "", fmt.Errorf("parse destination: %w", err) + } + } + + protocol, err := convertToFirewallProtocol(rule.Protocol) + if err != nil { + return "", fmt.Errorf("invalid protocol: %w", err) + } + + action, err := convertFirewallAction(rule.Action) + if err != nil { + return "", fmt.Errorf("invalid action: %w", err) + } + + dPorts := convertPortInfo(rule.PortInfo) + + addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) + if err != nil { + return "", fmt.Errorf("add route rule: %w", err) + } + + return id.RuleID(addedRule.GetRuleID()), nil } func (d *DefaultManager) protoRuleToFirewallRule( r *mgmProto.FirewallRule, ipsetName string, -) (string, []firewall.Rule, error) { +) (id.RuleID, []firewall.Rule, error) { ip := net.ParseIP(r.PeerIP) if ip == nil { return "", nil, fmt.Errorf("invalid IP address, skipping firewall rule") @@ -179,16 +264,16 @@ func (d *DefaultManager) protoRuleToFirewallRule( } } - ruleID := d.getRuleID(ip, protocol, int(r.Direction), port, action, "") - if rulesPair, ok := d.rulesPairs[ruleID]; ok { + ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") + if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { return ruleID, rulesPair, nil } var rules []firewall.Rule switch r.Direction { - case mgmProto.FirewallRule_IN: + case mgmProto.RuleDirection_IN: rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") - case mgmProto.FirewallRule_OUT: + case mgmProto.RuleDirection_OUT: rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") default: return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") @@ -210,7 +295,7 @@ func (d *DefaultManager) addInRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -221,7 +306,7 @@ func (d *DefaultManager) addInRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -239,7 +324,7 @@ func (d *DefaultManager) addOutRules( comment string, ) ([]firewall.Rule, error) { var rules []firewall.Rule - rule, err := d.firewall.AddFiltering( + rule, err := d.firewall.AddPeerFiltering( ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -250,7 +335,7 @@ func (d *DefaultManager) addOutRules( return rules, nil } - rule, err = d.firewall.AddFiltering( + rule, err = d.firewall.AddPeerFiltering( ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment) if err != nil { return nil, fmt.Errorf("failed to add firewall rule: %v", err) @@ -259,21 +344,21 @@ func (d *DefaultManager) addOutRules( return append(rules, rule...), nil } -// getRuleID() returns unique ID for the rule based on its parameters. -func (d *DefaultManager) getRuleID( +// getPeerRuleID() returns unique ID for the rule based on its parameters. +func (d *DefaultManager) getPeerRuleID( ip net.IP, proto firewall.Protocol, direction int, port *firewall.Port, action firewall.Action, comment string, -) string { +) id.RuleID { idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment if port != nil { idStr += port.String() } - return hex.EncodeToString(md5.New().Sum([]byte(idStr))) + return id.RuleID(hex.EncodeToString(md5.New().Sum([]byte(idStr)))) } // squashAcceptRules does complex logic to convert many rules which allows connection by traffic type @@ -283,7 +368,7 @@ func (d *DefaultManager) getRuleID( // but other has port definitions or has drop policy. func (d *DefaultManager) squashAcceptRules( networkMap *mgmProto.NetworkMap, -) ([]*mgmProto.FirewallRule, map[mgmProto.FirewallRuleProtocol]struct{}) { +) ([]*mgmProto.FirewallRule, map[mgmProto.RuleProtocol]struct{}) { totalIPs := 0 for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for range p.AllowedIps { @@ -291,14 +376,14 @@ func (d *DefaultManager) squashAcceptRules( } } - type protoMatch map[mgmProto.FirewallRuleProtocol]map[string]int + type protoMatch map[mgmProto.RuleProtocol]map[string]int in := protoMatch{} out := protoMatch{} // trace which type of protocols was squashed squashedRules := []*mgmProto.FirewallRule{} - squashedProtocols := map[mgmProto.FirewallRuleProtocol]struct{}{} + squashedProtocols := map[mgmProto.RuleProtocol]struct{}{} // this function we use to do calculation, can we squash the rules by protocol or not. // We summ amount of Peers IP for given protocol we found in original rules list. @@ -308,7 +393,7 @@ func (d *DefaultManager) squashAcceptRules( // // We zeroed this to notify squash function that this protocol can't be squashed. addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { - drop := r.Action == mgmProto.FirewallRule_DROP || r.Port != "" + drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" if drop { protocols[r.Protocol] = map[string]int{} return @@ -336,7 +421,7 @@ func (d *DefaultManager) squashAcceptRules( for i, r := range networkMap.FirewallRules { // calculate squash for different directions - if r.Direction == mgmProto.FirewallRule_IN { + if r.Direction == mgmProto.RuleDirection_IN { addRuleToCalculationMap(i, r, in) } else { addRuleToCalculationMap(i, r, out) @@ -345,14 +430,14 @@ func (d *DefaultManager) squashAcceptRules( // order of squashing by protocol is important // only for their first element ALL, it must be done first - protocolOrders := []mgmProto.FirewallRuleProtocol{ - mgmProto.FirewallRule_ALL, - mgmProto.FirewallRule_ICMP, - mgmProto.FirewallRule_TCP, - mgmProto.FirewallRule_UDP, + protocolOrders := []mgmProto.RuleProtocol{ + mgmProto.RuleProtocol_ALL, + mgmProto.RuleProtocol_ICMP, + mgmProto.RuleProtocol_TCP, + mgmProto.RuleProtocol_UDP, } - squash := func(matches protoMatch, direction mgmProto.FirewallRuleDirection) { + squash := func(matches protoMatch, direction mgmProto.RuleDirection) { for _, protocol := range protocolOrders { if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { // don't squash if : @@ -365,12 +450,12 @@ func (d *DefaultManager) squashAcceptRules( squashedRules = append(squashedRules, &mgmProto.FirewallRule{ PeerIP: "0.0.0.0", Direction: direction, - Action: mgmProto.FirewallRule_ACCEPT, + Action: mgmProto.RuleAction_ACCEPT, Protocol: protocol, }) squashedProtocols[protocol] = struct{}{} - if protocol == mgmProto.FirewallRule_ALL { + if protocol == mgmProto.RuleProtocol_ALL { // if we have ALL traffic type squashed rule // it allows all other type of traffic, so we can stop processing break @@ -378,11 +463,11 @@ func (d *DefaultManager) squashAcceptRules( } } - squash(in, mgmProto.FirewallRule_IN) - squash(out, mgmProto.FirewallRule_OUT) + squash(in, mgmProto.RuleDirection_IN) + squash(out, mgmProto.RuleDirection_OUT) // if all protocol was squashed everything is allow and we can ignore all other rules - if _, ok := squashedProtocols[mgmProto.FirewallRule_ALL]; ok { + if _, ok := squashedProtocols[mgmProto.RuleProtocol_ALL]; ok { return squashedRules, squashedProtocols } @@ -412,26 +497,26 @@ func (d *DefaultManager) getRuleGroupingSelector(rule *mgmProto.FirewallRule) st return fmt.Sprintf("%v:%v:%v:%s", strconv.Itoa(int(rule.Direction)), rule.Action, rule.Protocol, rule.Port) } -func (d *DefaultManager) rollBack(newRulePairs map[string][]firewall.Rule) { +func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) { log.Debugf("rollback ACL to previous state") for _, rules := range newRulePairs { for _, rule := range rules { - if err := d.firewall.DeleteRule(rule); err != nil { + if err := d.firewall.DeletePeerRule(rule); err != nil { log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) } } } } -func convertToFirewallProtocol(protocol mgmProto.FirewallRuleProtocol) (firewall.Protocol, error) { +func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewall.Protocol, error) { switch protocol { - case mgmProto.FirewallRule_TCP: + case mgmProto.RuleProtocol_TCP: return firewall.ProtocolTCP, nil - case mgmProto.FirewallRule_UDP: + case mgmProto.RuleProtocol_UDP: return firewall.ProtocolUDP, nil - case mgmProto.FirewallRule_ICMP: + case mgmProto.RuleProtocol_ICMP: return firewall.ProtocolICMP, nil - case mgmProto.FirewallRule_ALL: + case mgmProto.RuleProtocol_ALL: return firewall.ProtocolALL, nil default: return firewall.ProtocolALL, fmt.Errorf("invalid protocol type: %s", protocol.String()) @@ -442,13 +527,41 @@ func shouldSkipInvertedRule(protocol firewall.Protocol, port *firewall.Port) boo return protocol == firewall.ProtocolALL || protocol == firewall.ProtocolICMP || port == nil } -func convertFirewallAction(action mgmProto.FirewallRuleAction) (firewall.Action, error) { +func convertFirewallAction(action mgmProto.RuleAction) (firewall.Action, error) { switch action { - case mgmProto.FirewallRule_ACCEPT: + case mgmProto.RuleAction_ACCEPT: return firewall.ActionAccept, nil - case mgmProto.FirewallRule_DROP: + case mgmProto.RuleAction_DROP: return firewall.ActionDrop, nil default: return firewall.ActionDrop, fmt.Errorf("invalid action type: %d", action) } } + +func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port { + if portInfo == nil { + return nil + } + + if portInfo.GetPort() != 0 { + return &firewall.Port{ + Values: []int{int(portInfo.GetPort())}, + } + } + + if portInfo.GetRange() != nil { + return &firewall.Port{ + IsRange: true, + Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)}, + } + } + + return nil +} + +func getDefault(prefix netip.Prefix) netip.Prefix { + if prefix.Addr().Is6() { + return netip.PrefixFrom(netip.IPv6Unspecified(), 0) + } + return netip.PrefixFrom(netip.IPv4Unspecified(), 0) +} diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 494d54bf2..7d999669a 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,8 +9,8 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/acl/mocks" - "github.com/netbirdio/netbird/iface" mgmProto "github.com/netbirdio/netbird/management/proto" ) @@ -19,16 +19,16 @@ func TestDefaultManager(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, Port: "80", }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_UDP, Port: "53", }, }, @@ -65,16 +65,16 @@ func TestDefaultManager(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("firewall rules not applied: %v", acl.rulesPairs) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) return } }) t.Run("add extra rules", func(t *testing.T) { existedPairs := map[string]struct{}{} - for id := range acl.rulesPairs { - existedPairs[id] = struct{}{} + for id := range acl.peerRulesPairs { + existedPairs[id.GetRuleID()] = struct{}{} } // remove first rule @@ -83,24 +83,24 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules, &mgmProto.FirewallRule{ PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_DROP, - Protocol: mgmProto.FirewallRule_ICMP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_DROP, + Protocol: mgmProto.RuleProtocol_ICMP, }, ) acl.ApplyFiltering(networkMap) // we should have one old and one new rule in the existed rules - if len(acl.rulesPairs) != 2 { + if len(acl.peerRulesPairs) != 2 { t.Errorf("firewall rules not applied") return } // check that old rule was removed previousCount := 0 - for id := range acl.rulesPairs { - if _, ok := existedPairs[id]; ok { + for id := range acl.peerRulesPairs { + if _, ok := existedPairs[id.GetRuleID()]; ok { previousCount++ } } @@ -113,15 +113,15 @@ func TestDefaultManager(t *testing.T) { networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRulesIsEmpty = true - if acl.ApplyFiltering(networkMap); len(acl.rulesPairs) != 0 { - t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.rulesPairs)) + if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { + t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) return } networkMap.FirewallRulesIsEmpty = false acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 2 { - t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 2 { + t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) return } }) @@ -138,51 +138,51 @@ func TestDefaultManagerSquashRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, }, } @@ -199,13 +199,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_IN: + case r.Direction != mgmProto.RuleDirection_IN: t.Errorf("direction should be IN, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -215,13 +215,13 @@ func TestDefaultManagerSquashRules(t *testing.T) { case r.PeerIP != "0.0.0.0": t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) return - case r.Direction != mgmProto.FirewallRule_OUT: + case r.Direction != mgmProto.RuleDirection_OUT: t.Errorf("direction should be OUT, got: %v", r.Direction) return - case r.Protocol != mgmProto.FirewallRule_ALL: + case r.Protocol != mgmProto.RuleProtocol_ALL: t.Errorf("protocol should be ALL, got: %v", r.Protocol) return - case r.Action != mgmProto.FirewallRule_ACCEPT: + case r.Action != mgmProto.RuleAction_ACCEPT: t.Errorf("action should be ACCEPT, got: %v", r.Action) return } @@ -238,51 +238,51 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_ALL, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_ALL, }, { PeerIP: "10.93.0.4", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -308,21 +308,21 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { FirewallRules: []*mgmProto.FirewallRule{ { PeerIP: "10.93.0.1", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.2", - Direction: mgmProto.FirewallRule_IN, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_TCP, + Direction: mgmProto.RuleDirection_IN, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_TCP, }, { PeerIP: "10.93.0.3", - Direction: mgmProto.FirewallRule_OUT, - Action: mgmProto.FirewallRule_ACCEPT, - Protocol: mgmProto.FirewallRule_UDP, + Direction: mgmProto.RuleDirection_OUT, + Action: mgmProto.RuleAction_ACCEPT, + Protocol: mgmProto.RuleProtocol_UDP, }, }, } @@ -357,8 +357,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { acl.ApplyFiltering(networkMap) - if len(acl.rulesPairs) != 4 { - t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.rulesPairs)) + if len(acl.peerRulesPairs) != 4 { + t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) return } } diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 621b29513..3ed12b6dd 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -8,7 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - iface "github.com/netbirdio/netbird/iface" + iface "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" ) // MockIFaceMapper is a mock of IFaceMapper interface. @@ -77,7 +78,7 @@ func (mr *MockIFaceMapperMockRecorder) Name() *gomock.Call { } // SetFilter mocks base method. -func (m *MockIFaceMapper) SetFilter(arg0 iface.PacketFilter) error { +func (m *MockIFaceMapper) SetFilter(arg0 device.PacketFilter) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetFilter", arg0) ret0, _ := ret[0].(error) diff --git a/client/internal/config.go b/client/internal/config.go index 725703c43..ee54c6380 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -16,9 +16,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/ssh" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/util" ) @@ -117,6 +117,11 @@ type Config struct { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func ReadConfig(configPath string) (*Config, error) { if configFileIsExists(configPath) { + err := util.EnforcePermission(configPath) + if err != nil { + log.Errorf("failed to enforce permission on config dir: %v", err) + } + config := &Config{} if _, err := util.ReadJson(configPath, config); err != nil { return nil, err @@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if err != nil { return nil, err } - err = WriteOutConfig(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) return cfg, err } if isPreSharedKeyHidden(input.PreSharedKey) { input.PreSharedKey = nil } + err := util.EnforcePermission(input.ConfigPath) + if err != nil { + log.Errorf("failed to enforce permission on config dir: %v", err) + } return update(input) } diff --git a/client/internal/connect.go b/client/internal/connect.go index 1cfabe910..c77f95603 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -17,15 +17,18 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" - "github.com/netbirdio/netbird/iface" mgm "github.com/netbirdio/netbird/management/client" mgmProto "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/relay/auth/hmac" + relayClient "github.com/netbirdio/netbird/relay/client" signal "github.com/netbirdio/netbird/signal/client" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" @@ -55,22 +58,20 @@ func NewConnectClient( // Run with main logic. func (c *ConnectClient) Run() error { - return c.run(MobileDependency{}, nil, nil, nil, nil) + return c.run(MobileDependency{}, nil, nil) } // RunWithProbes runs the client's main logic with probes attached func (c *ConnectClient) RunWithProbes( - mgmProbe *Probe, - signalProbe *Probe, - relayProbe *Probe, - wgProbe *Probe, + probes *ProbeHolder, + runningChan chan error, ) error { - return c.run(MobileDependency{}, mgmProbe, signalProbe, relayProbe, wgProbe) + return c.run(MobileDependency{}, probes, runningChan) } // RunOnAndroid with main logic on mobile system func (c *ConnectClient) RunOnAndroid( - tunAdapter iface.TunAdapter, + tunAdapter device.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover, networkChangeListener listener.NetworkChangeListener, dnsAddresses []string, @@ -84,7 +85,7 @@ func (c *ConnectClient) RunOnAndroid( HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, } - return c.run(mobileDependency, nil, nil, nil, nil) + return c.run(mobileDependency, nil, nil) } func (c *ConnectClient) RunOniOS( @@ -100,15 +101,13 @@ func (c *ConnectClient) RunOniOS( NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, } - return c.run(mobileDependency, nil, nil, nil, nil) + return c.run(mobileDependency, nil, nil) } func (c *ConnectClient) run( mobileDependency MobileDependency, - mgmProbe *Probe, - signalProbe *Probe, - relayProbe *Probe, - wgProbe *Probe, + probes *ProbeHolder, + runningChan chan error, ) error { defer func() { if r := recover(); r != nil { @@ -160,12 +159,11 @@ func (c *ConnectClient) run( } defer c.statusRecorder.ClientStop() + runningChanOpen := true operation := func() error { // if context cancelled we not start new backoff cycle - select { - case <-c.ctx.Done(): + if c.isContextCancelled() { return nil - default: } state.Set(StatusConnecting) @@ -187,8 +185,7 @@ func (c *ConnectClient) run( log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host) defer func() { - err = mgmClient.Close() - if err != nil { + if err = mgmClient.Close(); err != nil { log.Warnf("failed to close the Management service client %v", err) } }() @@ -199,6 +196,7 @@ func (c *ConnectClient) run( log.Debug(err) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { state.Set(StatusNeedsLogin) + _ = c.Stop() return backoff.Permanent(wrapErr(err)) // unrecoverable error } return wrapErr(err) @@ -208,10 +206,9 @@ func (c *ConnectClient) run( localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: loginResp.GetPeerConfig().GetFqdn(), } - c.statusRecorder.UpdateLocalPeerState(localPeerState) signalURL := fmt.Sprintf("%s://%s", @@ -244,6 +241,23 @@ func (c *ConnectClient) run( c.statusRecorder.MarkSignalConnected() + relayURLs, token := parseRelayInfo(loginResp) + relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) + if len(relayURLs) > 0 { + if token != nil { + if err := relayManager.UpdateToken(token); err != nil { + log.Errorf("failed to update token: %s", err) + return wrapErr(err) + } + } + log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) + if err = relayManager.Serve(); err != nil { + log.Error(err) + return wrapErr(err) + } + c.statusRecorder.SetRelayMgr(relayManager) + } + peerConfig := loginResp.GetPeerConfig() engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig) @@ -255,11 +269,17 @@ func (c *ConnectClient) run( checks := loginResp.GetChecks() c.engineMutex.Lock() - c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, engineConfig, mobileDependency, c.statusRecorder, mgmProbe, signalProbe, relayProbe, wgProbe, checks) + if c.engine != nil && c.engine.ctx.Err() != nil { + log.Info("Stopping Netbird Engine") + if err := c.engine.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + } + c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) + c.engineMutex.Unlock() - err = c.engine.Start() - if err != nil { + if err := c.engine.Start(); err != nil { log.Errorf("error while starting Netbird Connection Engine: %s", err) return wrapErr(err) } @@ -267,17 +287,17 @@ func (c *ConnectClient) run( log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) state.Set(StatusConnected) + if runningChan != nil && runningChanOpen { + runningChan <- nil + close(runningChan) + runningChanOpen = false + } + <-engineCtx.Done() c.statusRecorder.ClientTeardown() backOff.Reset() - err = c.engine.Stop() - if err != nil { - log.Errorf("failed stopping engine %v", err) - return wrapErr(err) - } - log.Info("stopped NetBird client") if _, err := state.Status(); errors.Is(err, ErrResetConnection) { @@ -293,13 +313,31 @@ func (c *ConnectClient) run( log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { state.Set(StatusNeedsLogin) + _ = c.Stop() } return err } return nil } +func parseRelayInfo(loginResp *mgmProto.LoginResponse) ([]string, *hmac.Token) { + relayCfg := loginResp.GetWiretrusteeConfig().GetRelay() + if relayCfg == nil { + return nil, nil + } + + token := &hmac.Token{ + Payload: relayCfg.GetTokenPayload(), + Signature: relayCfg.GetTokenSignature(), + } + + return relayCfg.GetUrls(), token +} + func (c *ConnectClient) Engine() *Engine { + if c == nil { + return nil + } var e *Engine c.engineMutex.Lock() e = c.engine @@ -307,6 +345,28 @@ func (c *ConnectClient) Engine() *Engine { return e } +func (c *ConnectClient) Stop() error { + if c == nil { + return nil + } + c.engineMutex.Lock() + defer c.engineMutex.Unlock() + + if c.engine == nil { + return nil + } + return c.engine.Stop() +} + +func (c *ConnectClient) isContextCancelled() bool { + select { + case <-c.ctx.Done(): + return true + default: + return false + } +} + // createEngineConfig converts configuration received from Management Service to EngineConfig func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { nm := false @@ -397,19 +457,43 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal return notifier } -func freePort(start int) (int, error) { +// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port. +func freePort(initPort int) (int, error) { addr := net.UDPAddr{} - if start == 0 { - start = iface.DefaultWgPort + if initPort == 0 { + initPort = iface.DefaultWgPort } - for x := start; x <= 65535; x++ { - addr.Port = x - conn, err := net.ListenUDP("udp", &addr) - if err != nil { - continue - } - conn.Close() - return x, nil + + addr.Port = initPort + + conn, err := net.ListenUDP("udp", &addr) + if err == nil { + closeConnWithLog(conn) + return initPort, nil + } + + // if the port is already in use, ask the system for a free port + addr.Port = 0 + conn, err = net.ListenUDP("udp", &addr) + if err != nil { + return 0, fmt.Errorf("unable to get a free port: %v", err) + } + + udpAddr, ok := conn.LocalAddr().(*net.UDPAddr) + if !ok { + return 0, errors.New("wrong address type when getting a free port") + } + closeConnWithLog(conn) + return udpAddr.Port, nil +} + +func closeConnWithLog(conn *net.UDPConn) { + startClosing := time.Now() + err := conn.Close() + if err != nil { + log.Warnf("closing probe port %d failed: %v. NetBird will still attempt to use this port for connection.", conn.LocalAddr().(*net.UDPAddr).Port, err) + } + if time.Since(startClosing) > time.Second { + log.Warnf("closing the testing port %d took %s. Usually it is safe to ignore, but continuous warnings may indicate a problem.", conn.LocalAddr().(*net.UDPAddr).Port, time.Since(startClosing)) } - return 0, errors.New("no free ports") } diff --git a/client/internal/connect_test.go b/client/internal/connect_test.go index 6f4a6bbb7..78b4b06e8 100644 --- a/client/internal/connect_test.go +++ b/client/internal/connect_test.go @@ -7,51 +7,55 @@ import ( func Test_freePort(t *testing.T) { tests := []struct { - name string - port int - want int - wantErr bool + name string + port int + want int + shouldMatch bool }{ { - name: "available", - port: 51820, - want: 51820, - wantErr: false, + name: "not provided, fallback to default", + port: 0, + want: 51820, + shouldMatch: true, }, { - name: "notavailable", - port: 51830, - want: 51831, - wantErr: false, + name: "provided and available", + port: 51821, + want: 51821, + shouldMatch: true, }, { - name: "noports", - port: 65535, - want: 0, - wantErr: true, + name: "provided and not available", + port: 51830, + want: 51830, + shouldMatch: false, }, } + c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830}) + if err != nil { + t.Errorf("freePort error = %v", err) + } + defer func(c1 *net.UDPConn) { + _ = c1.Close() + }(c1) + for _, tt := range tests { - c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830}) - if err != nil { - t.Errorf("freePort error = %v", err) - } - c2, err := net.ListenUDP("udp", &net.UDPAddr{Port: 65535}) - if err != nil { - t.Errorf("freePort error = %v", err) - } t.Run(tt.name, func(t *testing.T) { got, err := freePort(tt.port) - if (err != nil) != tt.wantErr { - t.Errorf("freePort() error = %v, wantErr %v", err, tt.wantErr) - return + + if err != nil { + t.Errorf("got an error while getting free port: %v", err) } - if got != tt.want { - t.Errorf("freePort() = %v, want %v", got, tt.want) + + if tt.shouldMatch && got != tt.want { + t.Errorf("got a different port %v, want %v", got, tt.want) + } + + if !tt.shouldMatch && got == tt.want { + t.Errorf("got the same port %v, want a different port", tt.want) } }) - c1.Close() - c2.Close() + } } diff --git a/client/internal/dns/response_writer_test.go b/client/internal/dns/response_writer_test.go index 5a0047700..857964406 100644 --- a/client/internal/dns/response_writer_test.go +++ b/client/internal/dns/response_writer_test.go @@ -9,7 +9,7 @@ import ( "github.com/google/gopacket/layers" "github.com/miekg/dns" - "github.com/netbirdio/netbird/iface/mocks" + "github.com/netbirdio/netbird/client/iface/mocks" ) func TestResponseWriterLocalAddr(t *testing.T) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index b9552bc17..53d18a678 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -15,16 +15,18 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" + pfmock "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/stdnet" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/iface" - pfmock "github.com/netbirdio/netbird/iface/mocks" ) type mocWGIface struct { - filter iface.PacketFilter + filter device.PacketFilter } func (w *mocWGIface) Name() string { @@ -43,11 +45,11 @@ func (w *mocWGIface) ToInterface() *net.Interface { panic("implement me") } -func (w *mocWGIface) GetFilter() iface.PacketFilter { +func (w *mocWGIface) GetFilter() device.PacketFilter { return w.filter } -func (w *mocWGIface) GetDevice() *iface.DeviceWrapper { +func (w *mocWGIface) GetDevice() *device.FilteredDevice { panic("implement me") } @@ -59,13 +61,13 @@ func (w *mocWGIface) IsUserspaceBind() bool { return false } -func (w *mocWGIface) SetFilter(filter iface.PacketFilter) error { +func (w *mocWGIface) SetFilter(filter device.PacketFilter) error { w.filter = filter return nil } -func (w *mocWGIface) GetStats(_ string) (iface.WGStats, error) { - return iface.WGStats{}, nil +func (w *mocWGIface) GetStats(_ string) (configurer.WGStats, error) { + return configurer.WGStats{}, nil } var zoneRecords = []nbdns.SimpleRecord{ diff --git a/client/internal/dns/wgiface.go b/client/internal/dns/wgiface.go index 2f08e8d52..69bc83659 100644 --- a/client/internal/dns/wgiface.go +++ b/client/internal/dns/wgiface.go @@ -5,7 +5,9 @@ package dns import ( "net" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" ) // WGIface defines subset methods of interface required for manager @@ -14,7 +16,7 @@ type WGIface interface { Address() iface.WGAddress ToInterface() *net.Interface IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/internal/dns/wgiface_windows.go b/client/internal/dns/wgiface_windows.go index f8bb80fb9..765132fdb 100644 --- a/client/internal/dns/wgiface_windows.go +++ b/client/internal/dns/wgiface_windows.go @@ -1,14 +1,18 @@ package dns -import "github.com/netbirdio/netbird/iface" +import ( + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/device" +) // WGIface defines subset methods of interface required for manager type WGIface interface { Name() string Address() iface.WGAddress IsUserspaceBind() bool - GetFilter() iface.PacketFilter - GetDevice() *iface.DeviceWrapper - GetStats(peerKey string) (iface.WGStats, error) + GetFilter() device.PacketFilter + GetDevice() *device.FilteredDevice + GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/engine.go b/client/internal/engine.go index d65322d6a..c51901a22 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -13,6 +13,7 @@ import ( "slices" "strings" "sync" + "sync/atomic" "time" "github.com/pion/ice/v3" @@ -22,8 +23,12 @@ import ( "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/relay" @@ -34,11 +39,11 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgm "github.com/netbirdio/netbird/management/client" "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" + auth "github.com/netbirdio/netbird/relay/auth/hmac" + relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" @@ -101,7 +106,8 @@ type EngineConfig struct { // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. type Engine struct { // signal is a Signal Service client - signal signal.Client + signal signal.Client + signaler *peer.Signaler // mgmClient is a Management Service client mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer @@ -122,7 +128,8 @@ type Engine struct { // STUNs is a list of STUN servers used by ICE STUNs []*stun.URI // TURNs is a list of STUN servers used by ICE - TURNs []*stun.URI + TURNs []*stun.URI + stunTurn atomic.Value // clientRoutes is the most recent list of clientRoutes received from the Management Service clientRoutes route.HAMap @@ -134,7 +141,7 @@ type Engine struct { ctx context.Context cancel context.CancelFunc - wgInterface *iface.WGIface + wgInterface iface.IWGIface wgProxyFactory *wgproxy.Factory udpMux *bind.UniversalUDPMuxDefault @@ -155,15 +162,12 @@ type Engine struct { dnsServer dns.Server - mgmProbe *Probe - signalProbe *Probe - relayProbe *Probe - wgProbe *Probe - - wgConnWorker sync.WaitGroup + probes *ProbeHolder // checks are the client-applied posture checks that need to be evaluated on the client checks []*mgmProto.Checks + + relayManager *relayClient.Manager } // Peer is an instance of the Connection Peer @@ -178,6 +182,7 @@ func NewEngine( clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, + relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, @@ -188,13 +193,11 @@ func NewEngine( clientCancel, signalClient, mgmClient, + relayManager, config, mobileDep, statusRecorder, nil, - nil, - nil, - nil, checks, ) } @@ -205,21 +208,20 @@ func NewEngineWithProbes( clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, + relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, - mgmProbe *Probe, - signalProbe *Probe, - relayProbe *Probe, - wgProbe *Probe, + probes *ProbeHolder, checks []*mgmProto.Checks, ) *Engine { - return &Engine{ clientCtx: clientCtx, clientCancel: clientCancel, signal: signalClient, + signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), mgmClient: mgmClient, + relayManager: relayManager, peerConns: make(map[string]*peer.Conn), syncMsgMux: &sync.Mutex{}, config: config, @@ -229,22 +231,20 @@ func NewEngineWithProbes( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, - mgmProbe: mgmProbe, - signalProbe: signalProbe, - relayProbe: relayProbe, - wgProbe: wgProbe, + probes: probes, checks: checks, } } func (e *Engine) Stop() error { + if e == nil { + // this seems to be a very odd case but there was the possibility if the netbird down command comes before the engine is fully started + log.Debugf("tried stopping engine that is nil") + return nil + } e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - if e.cancel != nil { - e.cancel() - } - // stopping network monitor first to avoid starting the engine again if e.networkMonitor != nil { e.networkMonitor.Stop() @@ -253,36 +253,24 @@ func (e *Engine) Stop() error { err := e.removeAllPeers() if err != nil { - return err + return fmt.Errorf("failed to remove all peers: %s", err) } e.clientRoutesMu.Lock() e.clientRoutes = nil e.clientRoutesMu.Unlock() + if e.cancel != nil { + e.cancel() + } + // very ugly but we want to remove peers from the WireGuard interface first before removing interface. // Removing peers happens in the conn.Close() asynchronously time.Sleep(500 * time.Millisecond) e.close() - e.wgConnWorker.Wait() - - maxWaitTime := 5 * time.Second - timeout := time.After(maxWaitTime) - - for { - if !e.IsWGIfaceUp() { - log.Infof("stopped Netbird Engine") - return nil - } - - select { - case <-timeout: - return fmt.Errorf("timeout when waiting for interface shutdown") - default: - time.Sleep(100 * time.Millisecond) - } - } + log.Infof("stopped Netbird Engine") + return nil } // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services @@ -305,7 +293,7 @@ func (e *Engine) Start() error { e.wgInterface = wgIface userspace := e.wgInterface.IsUserspaceBind() - e.wgProxyFactory = wgproxy.NewFactory(e.ctx, userspace, e.config.WgPort) + e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort) if e.config.RosenpassEnabled { log.Infof("rosenpass is enabled") @@ -331,7 +319,7 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, initialRoutes) + e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) @@ -480,80 +468,45 @@ func (e *Engine) removePeer(peerKey string) error { conn, exists := e.peerConns[peerKey] if exists { delete(e.peerConns, peerKey) - err := conn.Close() - if err != nil { - switch err.(type) { - case *peer.ConnectionAlreadyClosedError: - return nil - default: - return err - } - } + conn.Close() } return nil } -func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error { - err := s.Send(&sProto.Message{ - Key: myKey.PublicKey().String(), - RemoteKey: remoteKey.String(), - Body: &sProto.Body{ - Type: sProto.Body_CANDIDATE, - Payload: candidate.Marshal(), - }, - }) - if err != nil { - return err - } - - return nil -} - -func sendSignal(message *sProto.Message, s signal.Client) error { - return s.Send(message) -} - -// SignalOfferAnswer signals either an offer or an answer to remote peer -func SignalOfferAnswer(offerAnswer peer.OfferAnswer, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, - isAnswer bool) error { - var t sProto.Body_Type - if isAnswer { - t = sProto.Body_ANSWER - } else { - t = sProto.Body_OFFER - } - - msg, err := signal.MarshalCredential(myKey, offerAnswer.WgListenPort, remoteKey, &signal.Credential{ - UFrag: offerAnswer.IceCredentials.UFrag, - Pwd: offerAnswer.IceCredentials.Pwd, - }, t, offerAnswer.RosenpassPubKey, offerAnswer.RosenpassAddr) - if err != nil { - return err - } - - err = s.Send(msg) - if err != nil { - return err - } - - return nil -} - func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() if update.GetWiretrusteeConfig() != nil { - err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns()) + wCfg := update.GetWiretrusteeConfig() + err := e.updateTURNs(wCfg.GetTurns()) if err != nil { - return err + return fmt.Errorf("update TURNs: %w", err) } - err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns()) + err = e.updateSTUNs(wCfg.GetStuns()) if err != nil { - return err + return fmt.Errorf("update STUNs: %w", err) } + var stunTurn []*stun.URI + stunTurn = append(stunTurn, e.STUNs...) + stunTurn = append(stunTurn, e.TURNs...) + e.stunTurn.Store(stunTurn) + + relayMsg := wCfg.GetRelay() + if relayMsg != nil { + c := &auth.Token{ + Payload: relayMsg.GetTokenPayload(), + Signature: relayMsg.GetTokenSignature(), + } + if err := e.relayManager.UpdateToken(c); err != nil { + log.Errorf("failed to update relay token: %v", err) + return fmt.Errorf("update relay token: %w", err) + } + } + + // todo update relay address in the relay manager // todo update signal } @@ -667,7 +620,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ IP: e.config.WgAddr, PubKey: e.config.WgPrivateKey.PublicKey().String(), - KernelInterface: iface.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded(), FQDN: conf.GetFqdn(), }) @@ -752,6 +705,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } + // Apply ACLs in the beginning to avoid security leaks + if e.acl != nil { + e.acl.ApplyFiltering(networkMap) + } + protoRoutes := networkMap.GetRoutes() if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} @@ -818,10 +776,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { log.Errorf("failed to update dns server, err: %v", err) } - if e.acl != nil { - e.acl.ApplyFiltering(networkMap) - } - e.networkSerial = serial // Test received (upstream) servers for availability right away instead of upon usage. @@ -949,68 +903,13 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err) } - e.wgConnWorker.Add(1) - go e.connWorker(conn, peerKey) + conn.Open() } return nil } -func (e *Engine) connWorker(conn *peer.Conn, peerKey string) { - defer e.wgConnWorker.Done() - for { - - // randomize starting time a bit - minValue := 500 - maxValue := 2000 - duration := time.Duration(rand.Intn(maxValue-minValue)+minValue) * time.Millisecond - select { - case <-e.ctx.Done(): - return - case <-time.After(duration): - } - - // if peer has been removed -> give up - if !e.peerExists(peerKey) { - log.Debugf("peer %s doesn't exist anymore, won't retry connection", peerKey) - return - } - - if !e.signal.Ready() { - log.Infof("signal client isn't ready, skipping connection attempt %s", peerKey) - continue - } - - // we might have received new STUN and TURN servers meanwhile, so update them - e.syncMsgMux.Lock() - conn.UpdateStunTurn(append(e.STUNs, e.TURNs...)) - e.syncMsgMux.Unlock() - - err := conn.Open(e.ctx) - if err != nil { - log.Debugf("connection to peer %s failed: %v", peerKey, err) - var connectionClosedError *peer.ConnectionClosedError - switch { - case errors.As(err, &connectionClosedError): - // conn has been forced to close, so we exit the loop - return - default: - } - } - } -} - -func (e *Engine) peerExists(peerKey string) bool { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() - _, ok := e.peerConns[peerKey] - return ok -} - func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, error) { log.Debugf("creating peer connection %s", pubKey) - var stunTurn []*stun.URI - stunTurn = append(stunTurn, e.STUNs...) - stunTurn = append(stunTurn, e.TURNs...) wgConfig := peer.WgConfig{ RemoteKey: pubKey, @@ -1043,52 +942,29 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e // randomize connection timeout timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond config := peer.ConnConfig{ - Key: pubKey, - LocalKey: e.config.WgPrivateKey.PublicKey().String(), - StunTurn: stunTurn, - InterfaceBlackList: e.config.IFaceBlackList, - DisableIPv6Discovery: e.config.DisableIPv6Discovery, - Timeout: timeout, - UDPMux: e.udpMux.UDPMuxDefault, - UDPMuxSrflx: e.udpMux, - WgConfig: wgConfig, - LocalWgPort: e.config.WgPort, - NATExternalIPs: e.parseNATExternalIPMappings(), - RosenpassPubKey: e.getRosenpassPubKey(), - RosenpassAddr: e.getRosenpassAddr(), + Key: pubKey, + LocalKey: e.config.WgPrivateKey.PublicKey().String(), + Timeout: timeout, + WgConfig: wgConfig, + LocalWgPort: e.config.WgPort, + RosenpassPubKey: e.getRosenpassPubKey(), + RosenpassAddr: e.getRosenpassAddr(), + ICEConfig: peer.ICEConfig{ + StunTurn: &e.stunTurn, + InterfaceBlackList: e.config.IFaceBlackList, + DisableIPv6Discovery: e.config.DisableIPv6Discovery, + UDPMux: e.udpMux.UDPMuxDefault, + UDPMuxSrflx: e.udpMux, + NATExternalIPs: e.parseNATExternalIPMappings(), + }, } - peerConn, err := peer.NewConn(config, e.statusRecorder, e.wgProxyFactory, e.mobileDep.TunAdapter, e.mobileDep.IFaceDiscover) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) if err != nil { return nil, err } - wgPubKey, err := wgtypes.ParseKey(pubKey) - if err != nil { - return nil, err - } - - signalOffer := func(offerAnswer peer.OfferAnswer) error { - return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, false) - } - - signalCandidate := func(candidate ice.Candidate) error { - return signalCandidate(candidate, e.config.WgPrivateKey, wgPubKey, e.signal) - } - - signalAnswer := func(offerAnswer peer.OfferAnswer) error { - return SignalOfferAnswer(offerAnswer, e.config.WgPrivateKey, wgPubKey, e.signal, true) - } - - peerConn.SetSignalCandidate(signalCandidate) - peerConn.SetSignalOffer(signalOffer) - peerConn.SetSignalAnswer(signalAnswer) - peerConn.SetSendSignalMessage(func(message *sProto.Message) error { - return sendSignal(message, e.signal) - }) - if e.rpManager != nil { - peerConn.SetOnConnected(e.rpManager.OnConnected) peerConn.SetOnDisconnected(e.rpManager.OnDisconnected) } @@ -1131,6 +1007,7 @@ func (e *Engine) receiveSignalEvents() { Version: msg.GetBody().GetNetBirdVersion(), RosenpassPubKey: rosenpassPubKey, RosenpassAddr: rosenpassAddr, + RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), }) case sProto.Body_ANSWER: remoteCred, err := signal.UnMarshalCredential(msg) @@ -1153,6 +1030,7 @@ func (e *Engine) receiveSignalEvents() { Version: msg.GetBody().GetNetBirdVersion(), RosenpassPubKey: rosenpassPubKey, RosenpassAddr: rosenpassAddr, + RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), }) case sProto.Body_CANDIDATE: candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) @@ -1161,7 +1039,7 @@ func (e *Engine) receiveSignalEvents() { return err } - conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) + go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) case sProto.Body_MODE: } @@ -1239,10 +1117,7 @@ func (e *Engine) close() { } // stop/restore DNS first so dbus and friends don't complain because of a missing interface - if e.dnsServer != nil { - e.dnsServer.Stop() - e.dnsServer = nil - } + e.stopDNSServer() if e.routeManager != nil { e.routeManager.Stop() @@ -1291,15 +1166,15 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) { log.Errorf("failed to create pion's stdnet: %s", err) } - var mArgs *iface.MobileIFaceArguments + var mArgs *device.MobileIFaceArguments switch runtime.GOOS { case "android": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunAdapter: e.mobileDep.TunAdapter, TunFd: int(e.mobileDep.FileDescriptor), } case "ios": - mArgs = &iface.MobileIFaceArguments{ + mArgs = &device.MobileIFaceArguments{ TunFd: int(e.mobileDep.FileDescriptor), } default: @@ -1415,24 +1290,27 @@ func (e *Engine) getRosenpassAddr() string { } func (e *Engine) receiveProbeEvents() { - if e.signalProbe != nil { - go e.signalProbe.Receive(e.ctx, func() bool { + if e.probes == nil { + return + } + if e.probes.SignalProbe != nil { + go e.probes.SignalProbe.Receive(e.ctx, func() bool { healthy := e.signal.IsHealthy() log.Debugf("received signal probe request, healthy: %t", healthy) return healthy }) } - if e.mgmProbe != nil { - go e.mgmProbe.Receive(e.ctx, func() bool { + if e.probes.MgmProbe != nil { + go e.probes.MgmProbe.Receive(e.ctx, func() bool { healthy := e.mgmClient.IsHealthy() log.Debugf("received management probe request, healthy: %t", healthy) return healthy }) } - if e.relayProbe != nil { - go e.relayProbe.Receive(e.ctx, func() bool { + if e.probes.RelayProbe != nil { + go e.probes.RelayProbe.Receive(e.ctx, func() bool { healthy := true results := append(e.probeSTUNs(), e.probeTURNs()...) @@ -1451,13 +1329,13 @@ func (e *Engine) receiveProbeEvents() { }) } - if e.wgProbe != nil { - go e.wgProbe.Receive(e.ctx, func() bool { + if e.probes.WgProbe != nil { + go e.probes.WgProbe.Receive(e.ctx, func() bool { log.Debug("received wg probe request") for _, peer := range e.peerConns { key := peer.GetKey() - wgStats, err := peer.GetConf().WgConfig.WgInterface.GetStats(key) + wgStats, err := peer.WgConfig().WgInterface.GetStats(key) if err != nil { log.Debugf("failed to get wg stats for peer %s: %s", key, err) } @@ -1481,12 +1359,16 @@ func (e *Engine) probeTURNs() []relay.ProbeResult { } func (e *Engine) restartEngine() { + log.Info("restarting engine") + CtxGetState(e.ctx).Set(StatusConnecting) + if err := e.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - if err := e.Start(); err != nil { - log.Errorf("Failed to start engine: %v", err) - } + + _ = CtxGetState(e.ctx).Wrap(ErrResetConnection) + log.Infof("cancelling client, engine will be recreated") + e.clientCancel() } func (e *Engine) startNetworkMonitor() { @@ -1508,6 +1390,7 @@ func (e *Engine) startNetworkMonitor() { defer mu.Unlock() if debounceTimer != nil { + log.Infof("Network monitor: detected network change, reset debounceTimer") debounceTimer.Stop() } @@ -1517,7 +1400,7 @@ func (e *Engine) startNetworkMonitor() { mu.Lock() defer mu.Unlock() - log.Infof("Network monitor detected network change, restarting engine") + log.Infof("Network monitor: detected network change, restarting engine") e.restartEngine() }) }) @@ -1542,26 +1425,23 @@ func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { return false, netip.Prefix{}, nil } +func (e *Engine) stopDNSServer() { + err := fmt.Errorf("DNS server stopped") + nsGroupStates := e.statusRecorder.GetDNSStates() + for i := range nsGroupStates { + nsGroupStates[i].Enabled = false + nsGroupStates[i].Error = err + } + e.statusRecorder.UpdateDNSStates(nsGroupStates) + if e.dnsServer != nil { + e.dnsServer.Stop() + e.dnsServer = nil + } +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.Equal(checks.Files, oChecks.Files) }) } - -func (e *Engine) IsWGIfaceUp() bool { - if e == nil || e.wgInterface == nil { - return false - } - iface, err := net.InterfaceByName(e.wgInterface.Name()) - if err != nil { - log.Debugf("failed to get interface by name %s: %v", e.wgInterface.Name(), err) - return false - } - - if iface.Flags&net.FlagUp != 0 { - return true - } - - return false -} diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index e0f85d211..29a8439a2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/pion/transport/v3/stdnet" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -24,19 +25,21 @@ import ( "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" mgmt "github.com/netbirdio/netbird/management/client" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/telemetry" + relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" signal "github.com/netbirdio/netbird/signal/client" "github.com/netbirdio/netbird/signal/proto" @@ -58,6 +61,12 @@ var ( } ) +func TestMain(m *testing.M) { + _ = util.InitLog("debug", "console") + code := m.Run() + os.Exit(code) +} + func TestEngine_SSH(t *testing.T) { // todo resolve test execution on freebsd if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" { @@ -73,13 +82,23 @@ func TestEngine_SSH(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ - WgIfaceName: "utun101", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - ServerSSHAllowed: true, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine( + ctx, cancel, + &signal.MockClient{}, + &mgmt.MockClient{}, + relayMgr, + &EngineConfig{ + WgIfaceName: "utun101", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + ServerSSHAllowed: true, + }, + MobileDependency{}, + peer.NewRecorder("https://mgm"), + nil, + ) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, @@ -208,21 +227,29 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ - WgIfaceName: "utun102", - WgAddr: "100.64.0.1/24", - WgPrivateKey: key, - WgPort: 33100, - }, MobileDependency{}, peer.NewRecorder("https://mgm"), nil) - newNet, err := stdnet.NewNet() - if err != nil { - t.Fatal(err) + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine( + ctx, cancel, + &signal.MockClient{}, + &mgmt.MockClient{}, + relayMgr, + &EngineConfig{ + WgIfaceName: "utun102", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + }, + MobileDependency{}, + peer.NewRecorder("https://mgm"), + nil) + + wgIface := &iface.MockWGIface{ + RemovePeerFunc: func(peerKey string) error { + return nil + }, } - engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil) - if err != nil { - t.Fatal(err) - } - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, nil) + engine.wgInterface = wgIface + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } @@ -404,8 +431,8 @@ func TestEngine_Sync(t *testing.T) { } return nil } - - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{ + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{ WgIfaceName: "utun103", WgAddr: "100.64.0.1/24", WgPrivateKey: key, @@ -564,7 +591,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, @@ -734,7 +762,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgAddr := fmt.Sprintf("100.66.%d.1/24", n) - engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{ WgIfaceName: wgIfaceName, WgAddr: wgAddr, WgPrivateKey: key, @@ -845,6 +874,8 @@ func TestEngine_MultiplePeers(t *testing.T) { engine.dnsServer = &dns.MockServer{} mu.Lock() defer mu.Unlock() + guid := fmt.Sprintf("{%s}", uuid.New().String()) + device.CustomWindowsGUIDString = strings.ToLower(guid) err = engine.Start() if err != nil { t.Errorf("unable to start engine for peer %d with error %v", j, err) @@ -1010,7 +1041,8 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin WgPort: wgPort, } - e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil + relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String()) + e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil e.ctx = ctx return e, err } @@ -1025,7 +1057,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) @@ -1044,6 +1076,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) config := &server.Config{ Stuns: []*server.Host{}, TURNConfig: &server.TURNConfig{}, + Relay: &server.Relay{ + Addresses: []string{"127.0.0.1:1234"}, + CredentialsTTL: util.Duration{Duration: time.Hour}, + Secret: "222222222222222222", + }, Signal: &server.Host{ Proto: "http", URI: "localhost:10000", @@ -1078,8 +1115,9 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error) if err != nil { return nil, "", err } - turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) + + secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2355c67c3..2b0c92cc6 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -1,16 +1,16 @@ package internal import ( + "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/stdnet" - "github.com/netbirdio/netbird/iface" ) // MobileDependency collect all dependencies for mobile platform type MobileDependency struct { // Android only - TunAdapter iface.TunAdapter + TunAdapter device.TunAdapter IFaceDiscover stdnet.ExternalIFaceDiscover NetworkChangeListener listener.NetworkChangeListener HostDNSAddresses []string diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index 29df7ea7f..4dc2c1aa3 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -24,7 +24,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { - log.Errorf("Network monitor: failed to close routing socket: %v", err) + log.Warnf("Network monitor: failed to close routing socket: %v", err) } }() @@ -32,7 +32,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca <-ctx.Done() err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { - log.Debugf("Network monitor: closed routing socket") + log.Debugf("Network monitor: closed routing socket: %v", err) } }() @@ -45,12 +45,12 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca n, err := unix.Read(fd, buf) if err != nil { if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Errorf("Network monitor: failed to read from routing socket: %v", err) + log.Warnf("Network monitor: failed to read from routing socket: %v", err) } continue } if n < unix.SizeofRtMsghdr { - log.Errorf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) continue } @@ -61,11 +61,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca case unix.RTM_ADD, syscall.RTM_DELETE: route, err := parseRouteMessage(buf[:n]) if err != nil { - log.Errorf("Network monitor: error parsing routing message: %v", err) + log.Debugf("Network monitor: error parsing routing message: %v", err) continue } - if !route.Dst.Addr().IsUnspecified() { + if route.Dst.Bits() != 0 { continue } diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go index f5cc19473..19648edba 100644 --- a/client/internal/networkmonitor/monitor_generic.go +++ b/client/internal/networkmonitor/monitor_generic.go @@ -59,7 +59,7 @@ func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error // recover in case sys ops panic defer func() { if r := recover(); r != nil { - err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack()) } }() diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index 308b2aa45..cd48c269d 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -3,252 +3,73 @@ package networkmonitor import ( "context" "fmt" - "net" - "net/netip" "strings" - "time" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) -const ( - unreachable = 0 - incomplete = 1 - probe = 2 - delay = 3 - stale = 4 - reachable = 5 - permanent = 6 - tbd = 7 -) - -const interval = 10 * time.Second - func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { - var neighborv4, neighborv6 *systemops.Neighbor - { - initialNeighbors, err := getNeighbors() - if err != nil { - return fmt.Errorf("get neighbors: %w", err) - } - - neighborv4 = assignNeighbor(nexthopv4, initialNeighbors) - neighborv6 = assignNeighbor(nexthopv6, initialNeighbors) + routeMonitor, err := systemops.NewRouteMonitor(ctx) + if err != nil { + return fmt.Errorf("failed to create route monitor: %w", err) } - log.Debugf("Network monitor: initial IPv4 neighbor: %v, IPv6 neighbor: %v", neighborv4, neighborv6) - - ticker := time.NewTicker(interval) - defer ticker.Stop() + defer func() { + if err := routeMonitor.Stop(); err != nil { + log.Errorf("Network monitor: failed to stop route monitor: %v", err) + } + }() for { select { case <-ctx.Done(): return ErrStopped - case <-ticker.C: - if changed(nexthopv4, neighborv4, nexthopv6, neighborv6) { - go callback() - return nil + case route := <-routeMonitor.RouteUpdates(): + if route.Destination.Bits() != 0 { + continue + } + + if routeChanged(route, nexthopv4, nexthopv6, callback) { + break } } } } -func assignNeighbor(nexthop systemops.Nexthop, initialNeighbors map[netip.Addr]systemops.Neighbor) *systemops.Neighbor { - if n, ok := initialNeighbors[nexthop.IP]; ok && - n.State != unreachable && - n.State != incomplete && - n.State != tbd { - return &n - } - return nil -} - -func changed( - nexthopv4 systemops.Nexthop, - neighborv4 *systemops.Neighbor, - nexthopv6 systemops.Nexthop, - neighborv6 *systemops.Neighbor, -) bool { - neighbors, err := getNeighbors() - if err != nil { - log.Errorf("network monitor: error fetching current neighbors: %v", err) - return false - } - if neighborChanged(nexthopv4, neighborv4, neighbors) || neighborChanged(nexthopv6, neighborv6, neighbors) { - return true - } - - routes, err := getRoutes() - if err != nil { - log.Errorf("network monitor: error fetching current routes: %v", err) - return false - } - - if routeChanged(nexthopv4, nexthopv4.Intf, routes) || routeChanged(nexthopv6, nexthopv6.Intf, routes) { - return true - } - - return false -} - -// routeChanged checks if the default routes still point to our nexthop/interface -func routeChanged(nexthop systemops.Nexthop, intf *net.Interface, routes []systemops.Route) bool { - if !nexthop.IP.IsValid() { - return false - } - - if isSoftInterface(nexthop.Intf.Name) { - log.Tracef("network monitor: ignoring default route change for soft interface %s", nexthop.Intf.Name) - return false - } - - unspec := getUnspecifiedPrefix(nexthop.IP) - defaultRoutes, foundMatchingRoute := processRoutes(nexthop, intf, routes, unspec) - - log.Tracef("network monitor: all default routes:\n%s", strings.Join(defaultRoutes, "\n")) - - if !foundMatchingRoute { - logRouteChange(nexthop.IP, intf) - return true - } - - return false -} - -func getUnspecifiedPrefix(ip netip.Addr) netip.Prefix { - if ip.Is6() { - return netip.PrefixFrom(netip.IPv6Unspecified(), 0) - } - return netip.PrefixFrom(netip.IPv4Unspecified(), 0) -} - -func processRoutes(nexthop systemops.Nexthop, nexthopIntf *net.Interface, routes []systemops.Route, unspec netip.Prefix) ([]string, bool) { - var defaultRoutes []string - foundMatchingRoute := false - - for _, r := range routes { - if r.Destination == unspec { - routeInfo := formatRouteInfo(r) - defaultRoutes = append(defaultRoutes, routeInfo) - - if r.Nexthop == nexthop.IP && compareIntf(r.Interface, nexthopIntf) == 0 { - foundMatchingRoute = true - log.Debugf("network monitor: found matching default route: %s", routeInfo) - } +func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + if isSoftInterface(intf) { + log.Debugf("Network monitor: ignoring default route change for soft interface %s", intf) + return false } } - return defaultRoutes, foundMatchingRoute -} - -func formatRouteInfo(r systemops.Route) string { - newIntf := "" - if r.Interface != nil { - newIntf = r.Interface.Name - } - return fmt.Sprintf("Nexthop: %s, Interface: %s", r.Nexthop, newIntf) -} - -func logRouteChange(ip netip.Addr, intf *net.Interface) { - oldIntf := "" - if intf != nil { - oldIntf = intf.Name - } - log.Infof("network monitor: default route for %s (%s) is gone or changed", ip, oldIntf) -} - -func neighborChanged(nexthop systemops.Nexthop, neighbor *systemops.Neighbor, neighbors map[netip.Addr]systemops.Neighbor) bool { - if neighbor == nil { - return false - } - - // TODO: consider non-local nexthops, e.g. on point-to-point interfaces - if n, ok := neighbors[nexthop.IP]; ok { - if n.State == unreachable || n.State == incomplete { - log.Infof("network monitor: neighbor %s (%s) is not reachable: %s", neighbor.IPAddress, neighbor.LinkLayerAddress, stateFromInt(n.State)) - return true - } else if n.InterfaceIndex != neighbor.InterfaceIndex { - log.Infof( - "network monitor: neighbor %s (%s) changed interface from '%s' (%d) to '%s' (%d): %s", - neighbor.IPAddress, - neighbor.LinkLayerAddress, - neighbor.InterfaceAlias, - neighbor.InterfaceIndex, - n.InterfaceAlias, - n.InterfaceIndex, - stateFromInt(n.State), - ) + switch route.Type { + case systemops.RouteModified: + // TODO: get routing table to figure out if our route is affected for modified routes + log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) + go callback() + return true + case systemops.RouteAdded: + if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { + log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) + go callback() + return true + } + case systemops.RouteDeleted: + if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) + go callback() return true } - } else { - log.Infof("network monitor: neighbor %s (%s) is gone", neighbor.IPAddress, neighbor.LinkLayerAddress) - return true } return false } -func getNeighbors() (map[netip.Addr]systemops.Neighbor, error) { - entries, err := systemops.GetNeighbors() - if err != nil { - return nil, fmt.Errorf("get neighbors: %w", err) - } - - neighbours := make(map[netip.Addr]systemops.Neighbor, len(entries)) - for _, entry := range entries { - neighbours[entry.IPAddress] = entry - } - - return neighbours, nil -} - -func getRoutes() ([]systemops.Route, error) { - entries, err := systemops.GetRoutes() - if err != nil { - return nil, fmt.Errorf("get routes: %w", err) - } - - return entries, nil -} - -func stateFromInt(state uint8) string { - switch state { - case unreachable: - return "unreachable" - case incomplete: - return "incomplete" - case probe: - return "probe" - case delay: - return "delay" - case stale: - return "stale" - case reachable: - return "reachable" - case permanent: - return "permanent" - case tbd: - return "tbd" - default: - return "unknown" - } -} - -func compareIntf(a, b *net.Interface) int { - switch { - case a == nil && b == nil: - return 0 - case a == nil: - return -1 - case b == nil: - return 1 - default: - return a.Index - b.Index - } -} - func isSoftInterface(name string) bool { return strings.Contains(strings.ToLower(name), "isatap") || strings.Contains(strings.ToLower(name), "teredo") } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 0d8fd932c..ad84bd700 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -2,570 +2,277 @@ package peer import ( "context" - "fmt" + "math/rand" "net" + "os" "runtime" "strings" "sync" "time" + "github.com/cenkalti/backoff/v4" "github.com/pion/ice/v3" - "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" - "github.com/netbirdio/netbird/iface/bind" + relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" - sProto "github.com/netbirdio/netbird/signal/proto" nbnet "github.com/netbirdio/netbird/util/net" - "github.com/netbirdio/netbird/version" ) -const ( - iceKeepAliveDefault = 4 * time.Second - iceDisconnectedTimeoutDefault = 6 * time.Second - // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package - iceRelayAcceptanceMinWaitDefault = 2 * time.Second +type ConnPriority int +const ( defaultWgKeepAlive = 25 * time.Second + + connPriorityRelay ConnPriority = 1 + connPriorityICETurn ConnPriority = 1 + connPriorityICEP2P ConnPriority = 2 ) type WgConfig struct { WgListenPort int RemoteKey string - WgInterface *iface.WGIface + WgInterface iface.IWGIface AllowedIps string PreSharedKey *wgtypes.Key } // ConnConfig is a peer Connection configuration type ConnConfig struct { - // Key is a public key of a remote peer Key string // LocalKey is a public key of a local peer LocalKey string - // StunTurn is a list of STUN and TURN URLs - StunTurn []*stun.URI - - // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering - // (e.g. if eth0 is in the list, host candidate of this interface won't be used) - InterfaceBlackList []string - DisableIPv6Discovery bool - Timeout time.Duration WgConfig WgConfig - UDPMux ice.UDPMux - UDPMuxSrflx ice.UniversalUDPMux - LocalWgPort int - NATExternalIPs []string - // RosenpassPubKey is this peer's Rosenpass public key RosenpassPubKey []byte // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port) RosenpassAddr string + + // ICEConfig ICE protocol configuration + ICEConfig ICEConfig } -// OfferAnswer represents a session establishment offer or answer -type OfferAnswer struct { - IceCredentials IceCredentials - // WgListenPort is a remote WireGuard listen port. - // This field is used when establishing a direct WireGuard connection without any proxy. - // We can set the remote peer's endpoint with this port. - WgListenPort int +type WorkerCallbacks struct { + OnRelayReadyCallback func(info RelayConnInfo) + OnRelayStatusChanged func(ConnStatus) - // Version of NetBird Agent - Version string - // RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message - // This value is the local Rosenpass server public key when sending the message - RosenpassPubKey []byte - // RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message - // This value is the local Rosenpass server address when sending the message - RosenpassAddr string -} - -// IceCredentials ICE protocol credentials struct -type IceCredentials struct { - UFrag string - Pwd string + OnICEConnReadyCallback func(ConnPriority, ICEConnInfo) + OnICEStatusChanged func(ConnStatus) } type Conn struct { - config ConnConfig - mu sync.Mutex - - // signalCandidate is a handler function to signal remote peer about local connection candidate - signalCandidate func(candidate ice.Candidate) error - // signalOffer is a handler function to signal remote peer our connection offer (credentials) - signalOffer func(OfferAnswer) error - signalAnswer func(OfferAnswer) error - sendSignalMessage func(message *sProto.Message) error - onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) - onDisconnected func(remotePeer string, wgIP string) - - // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection - remoteOffersCh chan OfferAnswer - // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection - remoteAnswerCh chan OfferAnswer - closeCh chan struct{} - ctx context.Context - notifyDisconnected context.CancelFunc - - agent *ice.Agent - status ConnStatus - + log *log.Entry + mu sync.Mutex + ctx context.Context + ctxCancel context.CancelFunc + config ConnConfig statusRecorder *Status - wgProxyFactory *wgproxy.Factory - wgProxy wgproxy.Proxy + wgProxyICE wgproxy.Proxy + wgProxyRelay wgproxy.Proxy + signaler *Signaler + relayManager *relayClient.Manager + allowedIPsIP string + handshaker *Handshaker - adapter iface.TunAdapter - iFaceDiscover stdnet.ExternalIFaceDiscover - sentExtraSrflx bool + onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) + onDisconnected func(remotePeer string, wgIP string) - connID nbnet.ConnectionID + statusRelay *AtomicConnStatus + statusICE *AtomicConnStatus + currentConnPriority ConnPriority + opened bool // this flag is used to prevent close in case of not opened connection + + workerICE *WorkerICE + workerRelay *WorkerRelay + + connIDRelay nbnet.ConnectionID + connIDICE nbnet.ConnectionID beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc -} -// GetConf returns the connection config -func (conn *Conn) GetConf() ConnConfig { - return conn.config -} + endpointRelay *net.UDPAddr -// WgConfig returns the WireGuard config -func (conn *Conn) WgConfig() WgConfig { - return conn.config.WgConfig -} - -// UpdateStunTurn update the turn and stun addresses -func (conn *Conn) UpdateStunTurn(turnStun []*stun.URI) { - conn.config.StunTurn = turnStun + // for reconnection operations + iCEDisconnected chan bool + relayDisconnected chan bool } // NewConn creates a new not opened Conn to the remote peer. // To establish a connection run Conn.Open -func NewConn(config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, adapter iface.TunAdapter, iFaceDiscover stdnet.ExternalIFaceDiscover) (*Conn, error) { - return &Conn{ - config: config, - mu: sync.Mutex{}, - status: StatusDisconnected, - closeCh: make(chan struct{}), - remoteOffersCh: make(chan OfferAnswer), - remoteAnswerCh: make(chan OfferAnswer), - statusRecorder: statusRecorder, - wgProxyFactory: wgProxyFactory, - adapter: adapter, - iFaceDiscover: iFaceDiscover, - }, nil +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, wgProxyFactory *wgproxy.Factory, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager) (*Conn, error) { + _, allowedIPsIP, err := net.ParseCIDR(config.WgConfig.AllowedIps) + if err != nil { + log.Errorf("failed to parse allowedIPS: %v", err) + return nil, err + } + + ctx, ctxCancel := context.WithCancel(engineCtx) + connLog := log.WithField("peer", config.Key) + + var conn = &Conn{ + log: connLog, + ctx: ctx, + ctxCancel: ctxCancel, + config: config, + statusRecorder: statusRecorder, + wgProxyFactory: wgProxyFactory, + signaler: signaler, + relayManager: relayManager, + allowedIPsIP: allowedIPsIP.String(), + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), + iCEDisconnected: make(chan bool, 1), + relayDisconnected: make(chan bool, 1), + } + + rFns := WorkerRelayCallbacks{ + OnConnReady: conn.relayConnectionIsReady, + OnDisconnected: conn.onWorkerRelayStateDisconnected, + } + + wFns := WorkerICECallbacks{ + OnConnReady: conn.iCEConnectionIsReady, + OnStatusChanged: conn.onWorkerICEStateDisconnected, + } + + conn.workerRelay = NewWorkerRelay(connLog, config, relayManager, rFns) + + relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() + conn.workerICE, err = NewWorkerICE(ctx, connLog, config, signaler, iFaceDiscover, statusRecorder, relayIsSupportedLocally, wFns) + if err != nil { + return nil, err + } + + conn.handshaker = NewHandshaker(ctx, connLog, config, signaler, conn.workerICE, conn.workerRelay) + + conn.handshaker.AddOnNewOfferListener(conn.workerRelay.OnNewOffer) + if os.Getenv("NB_FORCE_RELAY") != "true" { + conn.handshaker.AddOnNewOfferListener(conn.workerICE.OnNewOffer) + } + + go conn.handshaker.Listen() + + return conn, nil } -func (conn *Conn) reCreateAgent() error { +// Open opens connection to the remote peer +// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will +// be used. +func (conn *Conn) Open() { + conn.log.Debugf("open connection to peer") conn.mu.Lock() defer conn.mu.Unlock() - - failedTimeout := 6 * time.Second - - var err error - transportNet, err := conn.newStdNet() - if err != nil { - log.Errorf("failed to create pion's stdnet: %s", err) - } - - iceKeepAlive := iceKeepAlive() - iceDisconnectedTimeout := iceDisconnectedTimeout() - iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() - - agentConfig := &ice.AgentConfig{ - MulticastDNSMode: ice.MulticastDNSModeDisabled, - NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, - Urls: conn.config.StunTurn, - CandidateTypes: conn.candidateTypes(), - FailedTimeout: &failedTimeout, - InterfaceFilter: stdnet.InterfaceFilter(conn.config.InterfaceBlackList), - UDPMux: conn.config.UDPMux, - UDPMuxSrflx: conn.config.UDPMuxSrflx, - NAT1To1IPs: conn.config.NATExternalIPs, - Net: transportNet, - DisconnectedTimeout: &iceDisconnectedTimeout, - KeepaliveInterval: &iceKeepAlive, - RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, - } - - if conn.config.DisableIPv6Discovery { - agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} - } - - conn.agent, err = ice.NewAgent(agentConfig) - if err != nil { - return err - } - - err = conn.agent.OnCandidate(conn.onICECandidate) - if err != nil { - return err - } - - err = conn.agent.OnConnectionStateChange(conn.onICEConnectionStateChange) - if err != nil { - return err - } - - err = conn.agent.OnSelectedCandidatePairChange(conn.onICESelectedCandidatePair) - if err != nil { - return err - } - - err = conn.agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) { - err := conn.statusRecorder.UpdateLatency(conn.config.Key, p.Latency()) - if err != nil { - log.Debugf("failed to update latency for peer %s: %s", conn.config.Key, err) - return - } - }) - if err != nil { - return fmt.Errorf("failed setting binding response callback: %w", err) - } - - return nil -} - -func (conn *Conn) candidateTypes() []ice.CandidateType { - if hasICEForceRelayConn() { - return []ice.CandidateType{ice.CandidateTypeRelay} - } - // TODO: remove this once we have refactored userspace proxy into the bind package - if runtime.GOOS == "ios" { - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} - } - return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} -} - -// Open opens connection to the remote peer starting ICE candidate gathering process. -// Blocks until connection has been closed or connection timeout. -// ConnStatus will be set accordingly -func (conn *Conn) Open(ctx context.Context) error { - log.Debugf("trying to connect to peer %s", conn.config.Key) + conn.opened = true peerState := State{ PubKey: conn.config.Key, IP: strings.Split(conn.config.WgConfig.AllowedIps, "/")[0], ConnStatusUpdate: time.Now(), - ConnStatus: conn.status, + ConnStatus: StatusDisconnected, Mux: new(sync.RWMutex), } err := conn.statusRecorder.UpdatePeerState(peerState) if err != nil { - log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err) + conn.log.Warnf("error while updating the state err: %v", err) } - defer func() { - err := conn.cleanup() - if err != nil { - log.Warnf("error while cleaning up peer connection %s: %v", conn.config.Key, err) - return - } - }() + go conn.startHandshakeAndReconnect() +} - err = conn.reCreateAgent() +func (conn *Conn) startHandshakeAndReconnect() { + conn.waitInitialRandomSleepTime() + + err := conn.handshaker.sendOffer() if err != nil { - return err + conn.log.Errorf("failed to send initial offer: %v", err) } - err = conn.sendOffer() - if err != nil { - return err - } - - log.Debugf("connection offer sent to peer %s, waiting for the confirmation", conn.config.Key) - - // Only continue once we got a connection confirmation from the remote peer. - // The connection timeout could have happened before a confirmation received from the remote. - // The connection could have also been closed externally (e.g. when we received an update from the management that peer shouldn't be connected) - var remoteOfferAnswer OfferAnswer - select { - case remoteOfferAnswer = <-conn.remoteOffersCh: - // received confirmation from the remote peer -> ready to proceed - err = conn.sendAnswer() - if err != nil { - return err - } - case remoteOfferAnswer = <-conn.remoteAnswerCh: - case <-time.After(conn.config.Timeout): - return NewConnectionTimeoutError(conn.config.Key, conn.config.Timeout) - case <-conn.closeCh: - // closed externally - return NewConnectionClosedError(conn.config.Key) - } - - log.Debugf("received connection confirmation from peer %s running version %s and with remote WireGuard listen port %d", - conn.config.Key, remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort) - - // at this point we received offer/answer and we are ready to gather candidates - conn.mu.Lock() - conn.status = StatusConnecting - conn.ctx, conn.notifyDisconnected = context.WithCancel(ctx) - defer conn.notifyDisconnected() - conn.mu.Unlock() - - peerState = State{ - PubKey: conn.config.Key, - ConnStatus: conn.status, - ConnStatusUpdate: time.Now(), - Mux: new(sync.RWMutex), - } - err = conn.statusRecorder.UpdatePeerState(peerState) - if err != nil { - log.Warnf("error while updating the state of peer %s,err: %v", conn.config.Key, err) - } - - err = conn.agent.GatherCandidates() - if err != nil { - return fmt.Errorf("gather candidates: %v", err) - } - - // will block until connection succeeded - // but it won't release if ICE Agent went into Disconnected or Failed state, - // so we have to cancel it with the provided context once agent detected a broken connection - isControlling := conn.config.LocalKey > conn.config.Key - var remoteConn *ice.Conn - if isControlling { - remoteConn, err = conn.agent.Dial(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + if conn.workerRelay.IsController() { + conn.reconnectLoopWithRetry() } else { - remoteConn, err = conn.agent.Accept(conn.ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) - } - if err != nil { - return err - } - - // dynamically set remote WireGuard port if other side specified a different one from the default one - remoteWgPort := iface.DefaultWgPort - if remoteOfferAnswer.WgListenPort != 0 { - remoteWgPort = remoteOfferAnswer.WgListenPort - } - - // the ice connection has been established successfully so we are ready to start the proxy - remoteAddr, err := conn.configureConnection(remoteConn, remoteWgPort, remoteOfferAnswer.RosenpassPubKey, - remoteOfferAnswer.RosenpassAddr) - if err != nil { - return err - } - - log.Infof("connected to peer %s, endpoint address: %s", conn.config.Key, remoteAddr.String()) - - // wait until connection disconnected or has been closed externally (upper layer, e.g. engine) - select { - case <-conn.closeCh: - // closed externally - return NewConnectionClosedError(conn.config.Key) - case <-conn.ctx.Done(): - // disconnected from the remote peer - return NewConnectionDisconnectedError(conn.config.Key) + conn.reconnectLoopForOnDisconnectedEvent() } } -func isRelayCandidate(candidate ice.Candidate) bool { - return candidate.Type() == ice.CandidateTypeRelay +// Close closes this peer Conn issuing a close event to the Conn closeCh +func (conn *Conn) Close() { + conn.mu.Lock() + defer conn.mu.Unlock() + + conn.log.Infof("close peer connection") + conn.ctxCancel() + + if !conn.opened { + conn.log.Debugf("ignore close connection to peer") + return + } + + conn.workerRelay.DisableWgWatcher() + conn.workerRelay.CloseConn() + conn.workerICE.Close() + + if conn.wgProxyRelay != nil { + err := conn.wgProxyRelay.CloseConn() + if err != nil { + conn.log.Errorf("failed to close wg proxy for relay: %v", err) + } + conn.wgProxyRelay = nil + } + + if conn.wgProxyICE != nil { + err := conn.wgProxyICE.CloseConn() + if err != nil { + conn.log.Errorf("failed to close wg proxy for ice: %v", err) + } + conn.wgProxyICE = nil + } + + err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if err != nil { + conn.log.Errorf("failed to remove wg endpoint: %v", err) + } + + conn.freeUpConnID() + + if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { + conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) + } + + conn.setStatusToDisconnected() +} + +// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise +// doesn't block, discards the message if connection wasn't ready +func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { + conn.log.Debugf("OnRemoteAnswer, status ICE: %s, status relay: %s", conn.statusICE, conn.statusRelay) + return conn.handshaker.OnRemoteAnswer(answer) +} + +// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. +func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { + conn.workerICE.OnRemoteCandidate(candidate, haRoutes) } func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) { conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) } - func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) { conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) } -// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected -func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) { - conn.mu.Lock() - defer conn.mu.Unlock() - - pair, err := conn.agent.GetSelectedCandidatePair() - if err != nil { - return nil, err - } - - var endpoint net.Addr - if isRelayCandidate(pair.Local) { - log.Debugf("setup relay connection") - conn.wgProxy = conn.wgProxyFactory.GetProxy(conn.ctx) - endpoint, err = conn.wgProxy.AddTurnConn(remoteConn) - if err != nil { - return nil, err - } - } else { - // To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port - go conn.punchRemoteWGPort(pair, remoteWgPort) - endpoint = remoteConn.RemoteAddr() - } - - endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) - log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) - - conn.connID = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connID, endpointUdpAddr.IP); err != nil { - log.Errorf("Before add peer hook failed: %v", err) - } - } - - err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey) - if err != nil { - if conn.wgProxy != nil { - if err := conn.wgProxy.CloseConn(); err != nil { - log.Warnf("Failed to close turn connection: %v", err) - } - } - return nil, fmt.Errorf("update peer: %w", err) - } - - conn.status = StatusConnected - rosenpassEnabled := false - if remoteRosenpassPubKey != nil { - rosenpassEnabled = true - } - - peerState := State{ - PubKey: conn.config.Key, - ConnStatus: conn.status, - ConnStatusUpdate: time.Now(), - LocalIceCandidateType: pair.Local.Type().String(), - RemoteIceCandidateType: pair.Remote.Type().String(), - LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), - RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), - Direct: !isRelayCandidate(pair.Local), - RosenpassEnabled: rosenpassEnabled, - Mux: new(sync.RWMutex), - } - if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { - peerState.Relayed = true - } - - err = conn.statusRecorder.UpdatePeerState(peerState) - if err != nil { - log.Warnf("unable to save peer's state, got error: %v", err) - } - - _, ipNet, err := net.ParseCIDR(conn.config.WgConfig.AllowedIps) - if err != nil { - return nil, err - } - - if runtime.GOOS == "ios" { - runtime.GC() - } - - if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, ipNet.IP.String(), remoteRosenpassAddr) - } - - return endpoint, nil -} - -func (conn *Conn) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { - // wait local endpoint configuration - time.Sleep(time.Second) - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort)) - if err != nil { - log.Warnf("got an error while resolving the udp address, err: %s", err) - return - } - - mux, ok := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) - if !ok { - log.Warn("invalid udp mux conversion") - return - } - _, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr) - if err != nil { - log.Warnf("got an error while sending the punch packet, err: %s", err) - } -} - -// cleanup closes all open resources and sets status to StatusDisconnected -func (conn *Conn) cleanup() error { - log.Debugf("trying to cleanup %s", conn.config.Key) - conn.mu.Lock() - defer conn.mu.Unlock() - - conn.sentExtraSrflx = false - - var err1, err2, err3 error - if conn.agent != nil { - err1 = conn.agent.Close() - if err1 == nil { - conn.agent = nil - } - } - - if conn.wgProxy != nil { - err2 = conn.wgProxy.CloseConn() - conn.wgProxy = nil - } - - // todo: is it problem if we try to remove a peer what is never existed? - err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) - - if conn.connID != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connID); err != nil { - log.Errorf("After remove peer hook failed: %v", err) - } - } - } - conn.connID = "" - - if conn.notifyDisconnected != nil { - conn.notifyDisconnected() - conn.notifyDisconnected = nil - } - - if conn.status == StatusConnected && conn.onDisconnected != nil { - conn.onDisconnected(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps) - } - - conn.status = StatusDisconnected - - peerState := State{ - PubKey: conn.config.Key, - ConnStatus: conn.status, - ConnStatusUpdate: time.Now(), - Mux: new(sync.RWMutex), - } - err := conn.statusRecorder.UpdatePeerState(peerState) - if err != nil { - // pretty common error because by that time Engine can already remove the peer and status won't be available. - // todo rethink status updates - log.Debugf("error while updating peer's %s state, err: %v", conn.config.Key, err) - } - if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, iface.WGStats{}); err != nil { - log.Debugf("failed to reset wireguard stats for peer %s: %s", conn.config.Key, err) - } - - log.Debugf("cleaned up connection to peer %s", conn.config.Key) - if err1 != nil { - return err1 - } - if err2 != nil { - return err2 - } - return err3 -} - -// SetSignalOffer sets a handler function to be triggered by Conn when a new connection offer has to be signalled to the remote peer -func (conn *Conn) SetSignalOffer(handler func(offer OfferAnswer) error) { - conn.signalOffer = handler -} - // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { conn.onConnected = handler @@ -576,218 +283,521 @@ func (conn *Conn) SetOnDisconnected(handler func(remotePeer string, wgIP string) conn.onDisconnected = handler } -// SetSignalAnswer sets a handler function to be triggered by Conn when a new connection answer has to be signalled to the remote peer -func (conn *Conn) SetSignalAnswer(handler func(answer OfferAnswer) error) { - conn.signalAnswer = handler +func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { + conn.log.Debugf("OnRemoteOffer, on status ICE: %s, status Relay: %s", conn.statusICE, conn.statusRelay) + return conn.handshaker.OnRemoteOffer(offer) } -// SetSignalCandidate sets a handler function to be triggered by Conn when a new ICE local connection candidate has to be signalled to the remote peer -func (conn *Conn) SetSignalCandidate(handler func(candidate ice.Candidate) error) { - conn.signalCandidate = handler -} - -// SetSendSignalMessage sets a handler function to be triggered by Conn when there is new message to send via signal -func (conn *Conn) SetSendSignalMessage(handler func(message *sProto.Message) error) { - conn.sendSignalMessage = handler -} - -// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates -// and then signals them to the remote peer -func (conn *Conn) onICECandidate(candidate ice.Candidate) { - // nil means candidate gathering has been ended - if candidate == nil { - return - } - - // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored - log.Debugf("discovered local candidate %s", candidate.String()) - go func() { - err := conn.signalCandidate(candidate) - if err != nil { - log.Errorf("failed signaling candidate to the remote peer %s %s", conn.config.Key, err) - } - }() - - if !conn.shouldSendExtraSrflxCandidate(candidate) { - return - } - - // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) - // this is useful when network has an existing port forwarding rule for the wireguard port and this peer - extraSrflx, err := extraSrflxCandidate(candidate) - if err != nil { - log.Errorf("failed creating extra server reflexive candidate %s", err) - return - } - conn.sentExtraSrflx = true - - go func() { - err = conn.signalCandidate(extraSrflx) - if err != nil { - log.Errorf("failed signaling the extra server reflexive candidate to the remote peer %s: %s", conn.config.Key, err) - } - }() -} - -func (conn *Conn) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { - log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), - conn.config.Key) -} - -// onICEConnectionStateChange registers callback of an ICE Agent to track connection state -func (conn *Conn) onICEConnectionStateChange(state ice.ConnectionState) { - log.Debugf("peer %s ICE ConnectionState has changed to %s", conn.config.Key, state.String()) - if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { - conn.notifyDisconnected() - } -} - -func (conn *Conn) sendAnswer() error { - conn.mu.Lock() - defer conn.mu.Unlock() - - localUFrag, localPwd, err := conn.agent.GetLocalUserCredentials() - if err != nil { - return err - } - - log.Debugf("sending answer to %s", conn.config.Key) - err = conn.signalAnswer(OfferAnswer{ - IceCredentials: IceCredentials{localUFrag, localPwd}, - WgListenPort: conn.config.LocalWgPort, - Version: version.NetbirdVersion(), - RosenpassPubKey: conn.config.RosenpassPubKey, - RosenpassAddr: conn.config.RosenpassAddr, - }) - if err != nil { - return err - } - - return nil -} - -// sendOffer prepares local user credentials and signals them to the remote peer -func (conn *Conn) sendOffer() error { - conn.mu.Lock() - defer conn.mu.Unlock() - - localUFrag, localPwd, err := conn.agent.GetLocalUserCredentials() - if err != nil { - return err - } - err = conn.signalOffer(OfferAnswer{ - IceCredentials: IceCredentials{localUFrag, localPwd}, - WgListenPort: conn.config.LocalWgPort, - Version: version.NetbirdVersion(), - RosenpassPubKey: conn.config.RosenpassPubKey, - RosenpassAddr: conn.config.RosenpassAddr, - }) - if err != nil { - return err - } - return nil -} - -// Close closes this peer Conn issuing a close event to the Conn closeCh -func (conn *Conn) Close() error { - conn.mu.Lock() - defer conn.mu.Unlock() - select { - case conn.closeCh <- struct{}{}: - return nil - default: - // probably could happen when peer has been added and removed right after not even starting to connect - // todo further investigate - // this really happens due to unordered messages coming from management - // more importantly it causes inconsistency -> 2 Conn objects for the same peer - // e.g. this flow: - // update from management has peers: [1,2,3,4] - // engine creates a Conn for peers: [1,2,3,4] and schedules Open in ~1sec - // before conn.Open() another update from management arrives with peers: [1,2,3] - // engine removes peer 4 and calls conn.Close() which does nothing (this default clause) - // before conn.Open() another update from management arrives with peers: [1,2,3,4,5] - // engine adds a new Conn for 4 and 5 - // therefore peer 4 has 2 Conn objects - log.Warnf("Connection has been already closed or attempted closing not started connection %s", conn.config.Key) - return NewConnectionAlreadyClosed(conn.config.Key) - } +// WgConfig returns the WireGuard config +func (conn *Conn) WgConfig() WgConfig { + return conn.config.WgConfig } // Status returns current status of the Conn func (conn *Conn) Status() ConnStatus { conn.mu.Lock() defer conn.mu.Unlock() - return conn.status -} - -// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise -// doesn't block, discards the message if connection wasn't ready -func (conn *Conn) OnRemoteOffer(offer OfferAnswer) bool { - log.Debugf("OnRemoteOffer from peer %s on status %s", conn.config.Key, conn.status.String()) - - select { - case conn.remoteOffersCh <- offer: - return true - default: - log.Debugf("OnRemoteOffer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String()) - // connection might not be ready yet to receive so we ignore the message - return false - } -} - -// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise -// doesn't block, discards the message if connection wasn't ready -func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) bool { - log.Debugf("OnRemoteAnswer from peer %s on status %s", conn.config.Key, conn.status.String()) - - select { - case conn.remoteAnswerCh <- answer: - return true - default: - // connection might not be ready yet to receive so we ignore the message - log.Debugf("OnRemoteAnswer skipping message from peer %s on status %s because is not ready", conn.config.Key, conn.status.String()) - return false - } -} - -// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. -func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { - log.Debugf("OnRemoteCandidate from peer %s -> %s", conn.config.Key, candidate.String()) - go func() { - conn.mu.Lock() - defer conn.mu.Unlock() - - if conn.agent == nil { - return - } - - err := conn.agent.AddRemoteCandidate(candidate) - if err != nil { - log.Errorf("error while handling remote candidate from peer %s", conn.config.Key) - return - } - }() + return conn.evalStatus() } func (conn *Conn) GetKey() string { return conn.config.Key } -func (conn *Conn) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { - if !conn.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { - return true +func (conn *Conn) reconnectLoopWithRetry() { + // Give chance to the peer to establish the initial connection. + // With it, we can decrease to send necessary offer + select { + case <-conn.ctx.Done(): + case <-time.After(3 * time.Second): + } + + ticker := conn.prepareExponentTicker() + defer ticker.Stop() + time.Sleep(1 * time.Second) + for { + select { + case t := <-ticker.C: + if t.IsZero() { + // in case if the ticker has been canceled by context then avoid the temporary loop + return + } + + if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { + if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { + conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) + } + } else { + if conn.statusICE.Get() == StatusDisconnected { + conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) + } + } + + // checks if there is peer connection is established via relay or ice + if conn.isConnected() { + continue + } + + err := conn.handshaker.sendOffer() + if err != nil { + conn.log.Errorf("failed to do handshake: %v", err) + } + case changed := <-conn.relayDisconnected: + if !changed { + continue + } + conn.log.Debugf("Relay state changed, reset reconnect timer") + ticker.Stop() + ticker = conn.prepareExponentTicker() + case changed := <-conn.iCEDisconnected: + if !changed { + continue + } + conn.log.Debugf("ICE state changed, reset reconnect timer") + ticker.Stop() + ticker = conn.prepareExponentTicker() + case <-conn.ctx.Done(): + conn.log.Debugf("context is done, stop reconnect loop") + return + } } - return false } -func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { - relatedAdd := candidate.RelatedAddress() - return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ - Network: candidate.NetworkType().String(), - Address: candidate.Address(), - Port: relatedAdd.Port, - Component: candidate.Component(), - RelAddr: relatedAdd.Address, - RelPort: relatedAdd.Port, - }) +func (conn *Conn) prepareExponentTicker() *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: 0.01, + Multiplier: 2, + MaxInterval: conn.config.Timeout, + MaxElapsedTime: 0, + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, conn.ctx) + + ticker := backoff.NewTicker(bo) + <-ticker.C // consume the initial tick what is happening right after the ticker has been created + + return ticker +} + +// reconnectLoopForOnDisconnectedEvent is used when the peer is not a controller and it should reconnect to the peer +// when the connection is lost. It will try to establish a connection only once time if before the connection was established +// It track separately the ice and relay connection status. Just because a lover priority connection reestablished it does not +// mean that to switch to it. We always force to use the higher priority connection. +func (conn *Conn) reconnectLoopForOnDisconnectedEvent() { + for { + select { + case changed := <-conn.relayDisconnected: + if !changed { + continue + } + conn.log.Debugf("Relay state changed, try to send new offer") + case changed := <-conn.iCEDisconnected: + if !changed { + continue + } + conn.log.Debugf("ICE state changed, try to send new offer") + case <-conn.ctx.Done(): + conn.log.Debugf("context is done, stop reconnect loop") + return + } + + err := conn.handshaker.SendOffer() + if err != nil { + conn.log.Errorf("failed to do handshake: %v", err) + } + } +} + +// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected +func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICEConnInfo) { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.ctx.Err() != nil { + return + } + + conn.log.Debugf("ICE connection is ready") + + conn.statusICE.Set(StatusConnected) + + defer conn.updateIceState(iceConnInfo) + + if conn.currentConnPriority > priority { + return + } + + conn.log.Infof("set ICE to active connection") + + endpoint, wgProxy, err := conn.getEndpointForICEConnInfo(iceConnInfo) + if err != nil { + return + } + + endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) + conn.log.Debugf("Conn resolved IP is %s for endopint %s", endpoint, endpointUdpAddr.IP) + + conn.connIDICE = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDICE, endpointUdpAddr.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) + } + } + + conn.workerRelay.DisableWgWatcher() + + err = conn.configureWGEndpoint(endpointUdpAddr) + if err != nil { + if wgProxy != nil { + if err := wgProxy.CloseConn(); err != nil { + conn.log.Warnf("Failed to close turn connection: %v", err) + } + } + conn.log.Warnf("Failed to update wg peer configuration: %v", err) + return + } + wgConfigWorkaround() + + if conn.wgProxyICE != nil { + if err := conn.wgProxyICE.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + conn.wgProxyICE = wgProxy + + conn.currentConnPriority = priority + + conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) +} + +// todo review to make sense to handle connecting and disconnected status also? +func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.ctx.Err() != nil { + return + } + + conn.log.Tracef("ICE connection state changed to %s", newState) + + // switch back to relay connection + if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { + conn.log.Debugf("ICE disconnected, set Relay to active connection") + err := conn.configureWGEndpoint(conn.endpointRelay) + if err != nil { + conn.log.Errorf("failed to switch to relay conn: %v", err) + } + conn.workerRelay.EnableWgWatcher(conn.ctx) + conn.currentConnPriority = connPriorityRelay + } + + changed := conn.statusICE.Get() != newState && newState != StatusConnecting + conn.statusICE.Set(newState) + + select { + case conn.iCEDisconnected <- changed: + default: + } + + peerState := State{ + PubKey: conn.config.Key, + ConnStatus: conn.evalStatus(), + Relayed: conn.isRelayed(), + ConnStatusUpdate: time.Now(), + } + + err := conn.statusRecorder.UpdatePeerICEStateToDisconnected(peerState) + if err != nil { + conn.log.Warnf("unable to set peer's state to disconnected ice, got error: %v", err) + } +} + +func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.ctx.Err() != nil { + if err := rci.relayedConn.Close(); err != nil { + log.Warnf("failed to close unnecessary relayed connection: %v", err) + } + return + } + + conn.log.Debugf("Relay connection is ready to use") + conn.statusRelay.Set(StatusConnected) + + wgProxy := conn.wgProxyFactory.GetProxy() + endpoint, err := wgProxy.AddTurnConn(conn.ctx, rci.relayedConn) + if err != nil { + conn.log.Errorf("failed to add relayed net.Conn to local proxy: %v", err) + return + } + conn.log.Infof("created new wgProxy for relay connection: %s", endpoint) + + endpointUdpAddr, _ := net.ResolveUDPAddr(endpoint.Network(), endpoint.String()) + conn.endpointRelay = endpointUdpAddr + conn.log.Debugf("conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP) + + defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) + + if conn.currentConnPriority > connPriorityRelay { + if conn.statusICE.Get() == StatusConnected { + log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) + return + } + } + + conn.connIDRelay = nbnet.GenerateConnID() + for _, hook := range conn.beforeAddPeerHooks { + if err := hook(conn.connIDRelay, endpointUdpAddr.IP); err != nil { + conn.log.Errorf("Before add peer hook failed: %v", err) + } + } + + err = conn.configureWGEndpoint(endpointUdpAddr) + if err != nil { + if err := wgProxy.CloseConn(); err != nil { + conn.log.Warnf("Failed to close relay connection: %v", err) + } + conn.log.Errorf("Failed to update wg peer configuration: %v", err) + return + } + conn.workerRelay.EnableWgWatcher(conn.ctx) + wgConfigWorkaround() + + if conn.wgProxyRelay != nil { + if err := conn.wgProxyRelay.CloseConn(); err != nil { + conn.log.Warnf("failed to close deprecated wg proxy conn: %v", err) + } + } + conn.wgProxyRelay = wgProxy + conn.currentConnPriority = connPriorityRelay + + conn.log.Infof("start to communicate with peer via relay") + conn.doOnConnected(rci.rosenpassPubKey, rci.rosenpassAddr) +} + +func (conn *Conn) onWorkerRelayStateDisconnected() { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.ctx.Err() != nil { + return + } + + log.Debugf("relay connection is disconnected") + + if conn.currentConnPriority == connPriorityRelay { + log.Debugf("clean up WireGuard config") + err := conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) + if err != nil { + conn.log.Errorf("failed to remove wg endpoint: %v", err) + } + } + + if conn.wgProxyRelay != nil { + conn.endpointRelay = nil + _ = conn.wgProxyRelay.CloseConn() + conn.wgProxyRelay = nil + } + + changed := conn.statusRelay.Get() != StatusDisconnected + conn.statusRelay.Set(StatusDisconnected) + + select { + case conn.relayDisconnected <- changed: + default: + } + + peerState := State{ + PubKey: conn.config.Key, + ConnStatus: conn.evalStatus(), + Relayed: conn.isRelayed(), + ConnStatusUpdate: time.Now(), + } + + err := conn.statusRecorder.UpdatePeerRelayedStateToDisconnected(peerState) + if err != nil { + conn.log.Warnf("unable to save peer's state to Relay disconnected, got error: %v", err) + } +} + +func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { + return conn.config.WgConfig.WgInterface.UpdatePeer( + conn.config.WgConfig.RemoteKey, + conn.config.WgConfig.AllowedIps, + defaultWgKeepAlive, + addr, + conn.config.WgConfig.PreSharedKey, + ) +} + +func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { + peerState := State{ + PubKey: conn.config.Key, + ConnStatusUpdate: time.Now(), + ConnStatus: conn.evalStatus(), + Relayed: conn.isRelayed(), + RelayServerAddress: relayServerAddr, + RosenpassEnabled: isRosenpassEnabled(rosenpassPubKey), + } + + err := conn.statusRecorder.UpdatePeerRelayedState(peerState) + if err != nil { + conn.log.Warnf("unable to save peer's Relay state, got error: %v", err) + } +} + +func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) { + peerState := State{ + PubKey: conn.config.Key, + ConnStatusUpdate: time.Now(), + ConnStatus: conn.evalStatus(), + Relayed: iceConnInfo.Relayed, + LocalIceCandidateType: iceConnInfo.LocalIceCandidateType, + RemoteIceCandidateType: iceConnInfo.RemoteIceCandidateType, + LocalIceCandidateEndpoint: iceConnInfo.LocalIceCandidateEndpoint, + RemoteIceCandidateEndpoint: iceConnInfo.RemoteIceCandidateEndpoint, + RosenpassEnabled: isRosenpassEnabled(iceConnInfo.RosenpassPubKey), + } + + err := conn.statusRecorder.UpdatePeerICEState(peerState) + if err != nil { + conn.log.Warnf("unable to save peer's ICE state, got error: %v", err) + } +} + +func (conn *Conn) setStatusToDisconnected() { + conn.statusRelay.Set(StatusDisconnected) + conn.statusICE.Set(StatusDisconnected) + + peerState := State{ + PubKey: conn.config.Key, + ConnStatus: StatusDisconnected, + ConnStatusUpdate: time.Now(), + Mux: new(sync.RWMutex), + } + err := conn.statusRecorder.UpdatePeerState(peerState) + if err != nil { + // pretty common error because by that time Engine can already remove the peer and status won't be available. + // todo rethink status updates + conn.log.Debugf("error while updating peer's state, err: %v", err) + } + if err := conn.statusRecorder.UpdateWireGuardPeerState(conn.config.Key, configurer.WGStats{}); err != nil { + conn.log.Debugf("failed to reset wireguard stats for peer: %s", err) + } +} + +func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAddr string) { + if runtime.GOOS == "ios" { + runtime.GC() + } + + if conn.onConnected != nil { + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIPsIP, remoteRosenpassAddr) + } +} + +func (conn *Conn) waitInitialRandomSleepTime() { + minWait := 100 + maxWait := 800 + duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond + + timeout := time.NewTimer(duration) + defer timeout.Stop() + + select { + case <-conn.ctx.Done(): + case <-timeout.C: + } +} + +func (conn *Conn) isRelayed() bool { + if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) { + return false + } + + if conn.currentConnPriority == connPriorityICEP2P { + return false + } + + return true +} + +func (conn *Conn) evalStatus() ConnStatus { + if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected { + return StatusConnected + } + + if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting { + return StatusConnecting + } + + return StatusDisconnected +} + +func (conn *Conn) isConnected() bool { + conn.mu.Lock() + defer conn.mu.Unlock() + + if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { + return false + } + + if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { + if conn.statusRelay.Get() != StatusConnected { + return false + } + } + + return true +} + +func (conn *Conn) freeUpConnID() { + if conn.connIDRelay != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connIDRelay); err != nil { + conn.log.Errorf("After remove peer hook failed: %v", err) + } + } + conn.connIDRelay = "" + } + + if conn.connIDICE != "" { + for _, hook := range conn.afterRemovePeerHooks { + if err := hook(conn.connIDICE); err != nil { + conn.log.Errorf("After remove peer hook failed: %v", err) + } + } + conn.connIDICE = "" + } +} + +func (conn *Conn) getEndpointForICEConnInfo(iceConnInfo ICEConnInfo) (net.Addr, wgproxy.Proxy, error) { + if !iceConnInfo.RelayedOnLocal { + return iceConnInfo.RemoteConn.RemoteAddr(), nil, nil + } + conn.log.Debugf("setup ice turn connection") + wgProxy := conn.wgProxyFactory.GetProxy() + ep, err := wgProxy.AddTurnConn(conn.ctx, iceConnInfo.RemoteConn) + if err != nil { + conn.log.Errorf("failed to add turn net.Conn to local proxy: %v", err) + if errClose := wgProxy.CloseConn(); errClose != nil { + conn.log.Warnf("failed to close turn proxy connection: %v", errClose) + } + return nil, nil, err + } + return ep, wgProxy, nil +} + +func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { + return remoteRosenpassPubKey != nil +} + +// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update +// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard +func wgConfigWorkaround() { + time.Sleep(100 * time.Millisecond) } diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 639117c89..3c747864f 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -1,6 +1,10 @@ package peer -import log "github.com/sirupsen/logrus" +import ( + "sync/atomic" + + log "github.com/sirupsen/logrus" +) const ( // StatusConnected indicate the peer is in connected state @@ -12,7 +16,34 @@ const ( ) // ConnStatus describe the status of a peer's connection -type ConnStatus int +type ConnStatus int32 + +// AtomicConnStatus is a thread-safe wrapper for ConnStatus +type AtomicConnStatus struct { + status atomic.Int32 +} + +// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status +func NewAtomicConnStatus() *AtomicConnStatus { + acs := &AtomicConnStatus{} + acs.Set(StatusDisconnected) + return acs +} + +// Get returns the current connection status +func (acs *AtomicConnStatus) Get() ConnStatus { + return ConnStatus(acs.status.Load()) +} + +// Set updates the connection status +func (acs *AtomicConnStatus) Set(status ConnStatus) { + acs.status.Store(int32(status)) +} + +// String returns the string representation of the current status +func (acs *AtomicConnStatus) String() string { + return acs.Get().String() +} func (s ConnStatus) String() string { switch s { diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index b608a5929..b4926a9d2 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -2,25 +2,33 @@ package peer import ( "context" + "os" "sync" "testing" "time" "github.com/magiconair/properties/assert" - "github.com/pion/stun/v2" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/wgproxy" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/util" ) var connConf = ConnConfig{ - Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", - LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", - StunTurn: []*stun.URI{}, - InterfaceBlackList: nil, - Timeout: time.Second, - LocalWgPort: 51820, + Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + Timeout: time.Second, + LocalWgPort: 51820, + ICEConfig: ICEConfig{ + InterfaceBlackList: nil, + }, +} + +func TestMain(m *testing.M) { + _ = util.InitLog("trace", "console") + code := m.Run() + os.Exit(code) } func TestNewConn_interfaceFilter(t *testing.T) { @@ -36,11 +44,11 @@ func TestNewConn_interfaceFilter(t *testing.T) { } func TestConn_GetKey(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, nil, wgProxyFactory, nil, nil) + conn, err := NewConn(context.Background(), connConf, nil, wgProxyFactory, nil, nil, nil) if err != nil { return } @@ -51,11 +59,11 @@ func TestConn_GetKey(t *testing.T) { } func TestConn_OnRemoteOffer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) if err != nil { return } @@ -63,7 +71,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) go func() { - <-conn.remoteOffersCh + <-conn.handshaker.remoteOffersCh wg.Done() }() @@ -88,11 +96,11 @@ func TestConn_OnRemoteOffer(t *testing.T) { } func TestConn_OnRemoteAnswer(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) if err != nil { return } @@ -100,7 +108,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) go func() { - <-conn.remoteAnswerCh + <-conn.handshaker.remoteAnswerCh wg.Done() }() @@ -124,62 +132,42 @@ func TestConn_OnRemoteAnswer(t *testing.T) { wg.Wait() } func TestConn_Status(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) + wgProxyFactory := wgproxy.NewFactory(false, connConf.LocalWgPort) defer func() { _ = wgProxyFactory.Free() }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil, nil) if err != nil { return } tables := []struct { - name string - status ConnStatus - want ConnStatus + name string + statusIce ConnStatus + statusRelay ConnStatus + want ConnStatus }{ - {"StatusConnected", StatusConnected, StatusConnected}, - {"StatusDisconnected", StatusDisconnected, StatusDisconnected}, - {"StatusConnecting", StatusConnecting, StatusConnecting}, + {"StatusConnected", StatusConnected, StatusConnected, StatusConnected}, + {"StatusDisconnected", StatusDisconnected, StatusDisconnected, StatusDisconnected}, + {"StatusConnecting", StatusConnecting, StatusConnecting, StatusConnecting}, + {"StatusConnectingIce", StatusConnecting, StatusDisconnected, StatusConnecting}, + {"StatusConnectingIceAlternative", StatusConnecting, StatusConnected, StatusConnected}, + {"StatusConnectingRelay", StatusDisconnected, StatusConnecting, StatusConnecting}, + {"StatusConnectingRelayAlternative", StatusConnected, StatusConnecting, StatusConnected}, } for _, table := range tables { t.Run(table.name, func(t *testing.T) { - conn.status = table.status + si := NewAtomicConnStatus() + si.Set(table.statusIce) + conn.statusICE = si + + sr := NewAtomicConnStatus() + sr.Set(table.statusRelay) + conn.statusRelay = sr got := conn.Status() assert.Equal(t, got, table.want, "they should be equal") }) } } - -func TestConn_Close(t *testing.T) { - wgProxyFactory := wgproxy.NewFactory(context.Background(), false, connConf.LocalWgPort) - defer func() { - _ = wgProxyFactory.Free() - }() - conn, err := NewConn(connConf, NewRecorder("https://mgm"), wgProxyFactory, nil, nil) - if err != nil { - return - } - - wg := sync.WaitGroup{} - wg.Add(1) - go func() { - <-conn.closeCh - wg.Done() - }() - - go func() { - for { - err := conn.Close() - if err != nil { - continue - } else { - return - } - } - }() - - wg.Wait() -} diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go new file mode 100644 index 000000000..545f81966 --- /dev/null +++ b/client/internal/peer/handshaker.go @@ -0,0 +1,192 @@ +package peer + +import ( + "context" + "errors" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/version" +) + +var ( + ErrSignalIsNotReady = errors.New("signal is not ready") +) + +// IceCredentials ICE protocol credentials struct +type IceCredentials struct { + UFrag string + Pwd string +} + +// OfferAnswer represents a session establishment offer or answer +type OfferAnswer struct { + IceCredentials IceCredentials + // WgListenPort is a remote WireGuard listen port. + // This field is used when establishing a direct WireGuard connection without any proxy. + // We can set the remote peer's endpoint with this port. + WgListenPort int + + // Version of NetBird Agent + Version string + // RosenpassPubKey is the Rosenpass public key of the remote peer when receiving this message + // This value is the local Rosenpass server public key when sending the message + RosenpassPubKey []byte + // RosenpassAddr is the Rosenpass server address (IP:port) of the remote peer when receiving this message + // This value is the local Rosenpass server address when sending the message + RosenpassAddr string + + // relay server address + RelaySrvAddress string +} + +type Handshaker struct { + mu sync.Mutex + ctx context.Context + log *log.Entry + config ConnConfig + signaler *Signaler + ice *WorkerICE + relay *WorkerRelay + onNewOfferListeners []func(*OfferAnswer) + + // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection + remoteOffersCh chan OfferAnswer + // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection + remoteAnswerCh chan OfferAnswer +} + +func NewHandshaker(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay) *Handshaker { + return &Handshaker{ + ctx: ctx, + log: log, + config: config, + signaler: signaler, + ice: ice, + relay: relay, + remoteOffersCh: make(chan OfferAnswer), + remoteAnswerCh: make(chan OfferAnswer), + } +} + +func (h *Handshaker) AddOnNewOfferListener(offer func(remoteOfferAnswer *OfferAnswer)) { + h.onNewOfferListeners = append(h.onNewOfferListeners, offer) +} + +func (h *Handshaker) Listen() { + for { + h.log.Debugf("wait for remote offer confirmation") + remoteOfferAnswer, err := h.waitForRemoteOfferConfirmation() + if err != nil { + var connectionClosedError *ConnectionClosedError + if errors.As(err, &connectionClosedError) { + h.log.Tracef("stop handshaker") + return + } + h.log.Errorf("failed to received remote offer confirmation: %s", err) + continue + } + + h.log.Debugf("received connection confirmation, running version %s and with remote WireGuard listen port %d", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort) + for _, listener := range h.onNewOfferListeners { + go listener(remoteOfferAnswer) + } + } +} + +func (h *Handshaker) SendOffer() error { + h.mu.Lock() + defer h.mu.Unlock() + return h.sendOffer() +} + +// OnRemoteOffer handles an offer from the remote peer and returns true if the message was accepted, false otherwise +// doesn't block, discards the message if connection wasn't ready +func (h *Handshaker) OnRemoteOffer(offer OfferAnswer) bool { + select { + case h.remoteOffersCh <- offer: + return true + default: + h.log.Debugf("OnRemoteOffer skipping message because is not ready") + // connection might not be ready yet to receive so we ignore the message + return false + } +} + +// OnRemoteAnswer handles an offer from the remote peer and returns true if the message was accepted, false otherwise +// doesn't block, discards the message if connection wasn't ready +func (h *Handshaker) OnRemoteAnswer(answer OfferAnswer) bool { + select { + case h.remoteAnswerCh <- answer: + return true + default: + // connection might not be ready yet to receive so we ignore the message + h.log.Debugf("OnRemoteAnswer skipping message because is not ready") + return false + } +} + +func (h *Handshaker) waitForRemoteOfferConfirmation() (*OfferAnswer, error) { + select { + case remoteOfferAnswer := <-h.remoteOffersCh: + // received confirmation from the remote peer -> ready to proceed + err := h.sendAnswer() + if err != nil { + return nil, err + } + return &remoteOfferAnswer, nil + case remoteOfferAnswer := <-h.remoteAnswerCh: + return &remoteOfferAnswer, nil + case <-h.ctx.Done(): + // closed externally + return nil, NewConnectionClosedError(h.config.Key) + } +} + +// sendOffer prepares local user credentials and signals them to the remote peer +func (h *Handshaker) sendOffer() error { + if !h.signaler.Ready() { + return ErrSignalIsNotReady + } + + iceUFrag, icePwd := h.ice.GetLocalUserCredentials() + offer := OfferAnswer{ + IceCredentials: IceCredentials{iceUFrag, icePwd}, + WgListenPort: h.config.LocalWgPort, + Version: version.NetbirdVersion(), + RosenpassPubKey: h.config.RosenpassPubKey, + RosenpassAddr: h.config.RosenpassAddr, + } + + addr, err := h.relay.RelayInstanceAddress() + if err == nil { + offer.RelaySrvAddress = addr + } + + return h.signaler.SignalOffer(offer, h.config.Key) +} + +func (h *Handshaker) sendAnswer() error { + h.log.Debugf("sending answer") + uFrag, pwd := h.ice.GetLocalUserCredentials() + + answer := OfferAnswer{ + IceCredentials: IceCredentials{uFrag, pwd}, + WgListenPort: h.config.LocalWgPort, + Version: version.NetbirdVersion(), + RosenpassPubKey: h.config.RosenpassPubKey, + RosenpassAddr: h.config.RosenpassAddr, + } + addr, err := h.relay.RelayInstanceAddress() + if err == nil { + answer.RelaySrvAddress = addr + } + + err = h.signaler.SignalAnswer(answer, h.config.Key) + if err != nil { + return err + } + + return nil +} diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go new file mode 100644 index 000000000..713123e5d --- /dev/null +++ b/client/internal/peer/signaler.go @@ -0,0 +1,70 @@ +package peer + +import ( + "github.com/pion/ice/v3" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + signal "github.com/netbirdio/netbird/signal/client" + sProto "github.com/netbirdio/netbird/signal/proto" +) + +type Signaler struct { + signal signal.Client + wgPrivateKey wgtypes.Key +} + +func NewSignaler(signal signal.Client, wgPrivateKey wgtypes.Key) *Signaler { + return &Signaler{ + signal: signal, + wgPrivateKey: wgPrivateKey, + } +} + +func (s *Signaler) SignalOffer(offer OfferAnswer, remoteKey string) error { + return s.signalOfferAnswer(offer, remoteKey, sProto.Body_OFFER) +} + +func (s *Signaler) SignalAnswer(offer OfferAnswer, remoteKey string) error { + return s.signalOfferAnswer(offer, remoteKey, sProto.Body_ANSWER) +} + +func (s *Signaler) SignalICECandidate(candidate ice.Candidate, remoteKey string) error { + return s.signal.Send(&sProto.Message{ + Key: s.wgPrivateKey.PublicKey().String(), + RemoteKey: remoteKey, + Body: &sProto.Body{ + Type: sProto.Body_CANDIDATE, + Payload: candidate.Marshal(), + }, + }) +} + +func (s *Signaler) Ready() bool { + return s.signal.Ready() +} + +// SignalOfferAnswer signals either an offer or an answer to remote peer +func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { + msg, err := signal.MarshalCredential( + s.wgPrivateKey, + offerAnswer.WgListenPort, + remoteKey, + &signal.Credential{ + UFrag: offerAnswer.IceCredentials.UFrag, + Pwd: offerAnswer.IceCredentials.Pwd, + }, + bodyType, + offerAnswer.RosenpassPubKey, + offerAnswer.RosenpassAddr, + offerAnswer.RelaySrvAddress) + if err != nil { + return err + } + + err = s.signal.Send(msg) + if err != nil { + return err + } + + return nil +} diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index a7cfb95c4..a28992fac 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -3,6 +3,7 @@ package peer import ( "errors" "net/netip" + "slices" "sync" "time" @@ -10,9 +11,10 @@ import ( "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/relay" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" + relayClient "github.com/netbirdio/netbird/relay/client" ) // State contains the latest state of a peer @@ -24,11 +26,11 @@ type State struct { ConnStatus ConnStatus ConnStatusUpdate time.Time Relayed bool - Direct bool LocalIceCandidateType string RemoteIceCandidateType string LocalIceCandidateEndpoint string RemoteIceCandidateEndpoint string + RelayServerAddress string LastWireguardHandshake time.Time BytesTx int64 BytesRx int64 @@ -142,6 +144,8 @@ type Status struct { // Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events // set to true this variable and at the end of the processing we will reset it by the FinishPeerListModifications() peerListChangedForNotification bool + + relayMgr *relayClient.Manager } // NewRecorder returns a new Status instance @@ -156,6 +160,12 @@ func NewRecorder(mgmAddress string) *Status { } } +func (d *Status) SetRelayMgr(manager *relayClient.Manager) { + d.mux.Lock() + defer d.mux.Unlock() + d.relayMgr = manager +} + // ReplaceOfflinePeers replaces func (d *Status) ReplaceOfflinePeers(replacement []State) { d.mux.Lock() @@ -193,7 +203,7 @@ func (d *Status) GetPeer(peerPubKey string) (State, error) { state, ok := d.peers[peerPubKey] if !ok { - return State{}, iface.ErrPeerNotFound + return State{}, configurer.ErrPeerNotFound } return state, nil } @@ -231,17 +241,17 @@ func (d *Status) UpdatePeerState(receivedState State) error { peerState.SetRoutes(receivedState.GetRoutes()) } - skipNotification := shouldSkipNotify(receivedState, peerState) + skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) if receivedState.ConnStatus != peerState.ConnStatus { peerState.ConnStatus = receivedState.ConnStatus peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate - peerState.Direct = receivedState.Direct peerState.Relayed = receivedState.Relayed peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint + peerState.RelayServerAddress = receivedState.RelayServerAddress peerState.RosenpassEnabled = receivedState.RosenpassEnabled } @@ -261,8 +271,148 @@ func (d *Status) UpdatePeerState(receivedState State) error { return nil } +func (d *Status) UpdatePeerICEState(receivedState State) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[receivedState.PubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + if receivedState.IP != "" { + peerState.IP = receivedState.IP + } + + skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) + + peerState.ConnStatus = receivedState.ConnStatus + peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate + peerState.Relayed = receivedState.Relayed + peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType + peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType + peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint + peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint + peerState.RosenpassEnabled = receivedState.RosenpassEnabled + + d.peers[receivedState.PubKey] = peerState + + if skipNotification { + return nil + } + + ch, found := d.changeNotify[receivedState.PubKey] + if found && ch != nil { + close(ch) + d.changeNotify[receivedState.PubKey] = nil + } + + d.notifyPeerListChanged() + return nil +} + +func (d *Status) UpdatePeerRelayedState(receivedState State) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[receivedState.PubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) + + peerState.ConnStatus = receivedState.ConnStatus + peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate + peerState.Relayed = receivedState.Relayed + peerState.RelayServerAddress = receivedState.RelayServerAddress + peerState.RosenpassEnabled = receivedState.RosenpassEnabled + + d.peers[receivedState.PubKey] = peerState + + if skipNotification { + return nil + } + + ch, found := d.changeNotify[receivedState.PubKey] + if found && ch != nil { + close(ch) + d.changeNotify[receivedState.PubKey] = nil + } + + d.notifyPeerListChanged() + return nil +} + +func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[receivedState.PubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) + + peerState.ConnStatus = receivedState.ConnStatus + peerState.Relayed = receivedState.Relayed + peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate + peerState.RelayServerAddress = "" + + d.peers[receivedState.PubKey] = peerState + + if skipNotification { + return nil + } + + ch, found := d.changeNotify[receivedState.PubKey] + if found && ch != nil { + close(ch) + d.changeNotify[receivedState.PubKey] = nil + } + + d.notifyPeerListChanged() + return nil +} + +func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { + d.mux.Lock() + defer d.mux.Unlock() + + peerState, ok := d.peers[receivedState.PubKey] + if !ok { + return errors.New("peer doesn't exist") + } + + skipNotification := shouldSkipNotify(receivedState.ConnStatus, peerState) + + peerState.ConnStatus = receivedState.ConnStatus + peerState.Relayed = receivedState.Relayed + peerState.ConnStatusUpdate = receivedState.ConnStatusUpdate + peerState.LocalIceCandidateType = receivedState.LocalIceCandidateType + peerState.RemoteIceCandidateType = receivedState.RemoteIceCandidateType + peerState.LocalIceCandidateEndpoint = receivedState.LocalIceCandidateEndpoint + peerState.RemoteIceCandidateEndpoint = receivedState.RemoteIceCandidateEndpoint + + d.peers[receivedState.PubKey] = peerState + + if skipNotification { + return nil + } + + ch, found := d.changeNotify[receivedState.PubKey] + if found && ch != nil { + close(ch) + d.changeNotify[receivedState.PubKey] = nil + } + + d.notifyPeerListChanged() + return nil +} + // UpdateWireGuardPeerState updates the WireGuard bits of the peer state -func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) error { +func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats configurer.WGStats) error { d.mux.Lock() defer d.mux.Unlock() @@ -280,13 +430,13 @@ func (d *Status) UpdateWireGuardPeerState(pubKey string, wgStats iface.WGStats) return nil } -func shouldSkipNotify(received, curr State) bool { +func shouldSkipNotify(receivedConnStatus ConnStatus, curr State) bool { switch { - case received.ConnStatus == StatusConnecting: + case receivedConnStatus == StatusConnecting: return true - case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting: + case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusConnecting: return true - case received.ConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected: + case receivedConnStatus == StatusDisconnected && curr.ConnStatus == StatusDisconnected: return curr.IP != "" default: return false @@ -447,6 +597,8 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { } func (d *Status) GetRosenpassState() RosenpassState { + d.mux.Lock() + defer d.mux.Unlock() return RosenpassState{ d.rosenpassEnabled, d.rosenpassPermissive, @@ -454,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState { } func (d *Status) GetManagementState() ManagementState { + d.mux.Lock() + defer d.mux.Unlock() return ManagementState{ d.mgmAddress, d.managementState, @@ -495,6 +649,8 @@ func (d *Status) IsLoginRequired() bool { } func (d *Status) GetSignalState() SignalState { + d.mux.Lock() + defer d.mux.Unlock() return SignalState{ d.signalAddress, d.signalState, @@ -502,11 +658,42 @@ func (d *Status) GetSignalState() SignalState { } } +// GetRelayStates returns the stun/turn/permanent relay states func (d *Status) GetRelayStates() []relay.ProbeResult { - return d.relayStates + d.mux.Lock() + defer d.mux.Unlock() + if d.relayMgr == nil { + return d.relayStates + } + + // extend the list of stun, turn servers with relay address + relayStates := slices.Clone(d.relayStates) + + var relayState relay.ProbeResult + + // if the server connection is not established then we will use the general address + // in case of connection we will use the instance specific address + instanceAddr, err := d.relayMgr.RelayInstanceAddress() + if err != nil { + // TODO add their status + if errors.Is(err, relayClient.ErrRelayClientNotConnected) { + for _, r := range d.relayMgr.ServerURLs() { + relayStates = append(relayStates, relay.ProbeResult{ + URI: r, + }) + } + return relayStates + } + relayState.Err = err + } + + relayState.URI = instanceAddr + return append(relayStates, relayState) } func (d *Status) GetDNSStates() []NSGroupState { + d.mux.Lock() + defer d.mux.Unlock() return d.nsGroupStates } @@ -518,24 +705,24 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { // GetFullStatus gets full status func (d *Status) GetFullStatus() FullStatus { - d.mux.Lock() - defer d.mux.Unlock() - fullStatus := FullStatus{ ManagementState: d.GetManagementState(), SignalState: d.GetSignalState(), - LocalPeerState: d.localPeer, Relays: d.GetRelayStates(), RosenpassState: d.GetRosenpassState(), NSGroupStates: d.GetDNSStates(), } + d.mux.Lock() + defer d.mux.Unlock() + + fullStatus.LocalPeerState = d.localPeer + for _, status := range d.peers { fullStatus.Peers = append(fullStatus.Peers, status) } fullStatus.Peers = append(fullStatus.Peers, d.offlinePeers...) - return fullStatus } diff --git a/client/internal/peer/status_test.go b/client/internal/peer/status_test.go index a4a6e6081..1d283433b 100644 --- a/client/internal/peer/status_test.go +++ b/client/internal/peer/status_test.go @@ -2,8 +2,8 @@ package peer import ( "errors" - "testing" "sync" + "testing" "github.com/stretchr/testify/assert" ) @@ -43,7 +43,7 @@ func TestUpdatePeerState(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -64,7 +64,7 @@ func TestStatus_UpdatePeerFQDN(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -83,7 +83,7 @@ func TestGetPeerStateChangeNotifierLogic(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState @@ -108,7 +108,7 @@ func TestRemovePeer(t *testing.T) { status := NewRecorder("https://mgm") peerState := State{ PubKey: key, - Mux: new(sync.RWMutex), + Mux: new(sync.RWMutex), } status.peers[key] = peerState diff --git a/client/internal/peer/stdnet.go b/client/internal/peer/stdnet.go index 13f5886f5..ae31ebbf0 100644 --- a/client/internal/peer/stdnet.go +++ b/client/internal/peer/stdnet.go @@ -6,6 +6,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" ) -func (conn *Conn) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNet(conn.config.InterfaceBlackList) +func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { + return stdnet.NewNet(w.config.ICEConfig.InterfaceBlackList) } diff --git a/client/internal/peer/stdnet_android.go b/client/internal/peer/stdnet_android.go index 8a2454371..b411405bb 100644 --- a/client/internal/peer/stdnet_android.go +++ b/client/internal/peer/stdnet_android.go @@ -2,6 +2,6 @@ package peer import "github.com/netbirdio/netbird/client/internal/stdnet" -func (conn *Conn) newStdNet() (*stdnet.Net, error) { - return stdnet.NewNetWithDiscover(conn.iFaceDiscover, conn.config.InterfaceBlackList) +func (w *WorkerICE) newStdNet() (*stdnet.Net, error) { + return stdnet.NewNetWithDiscover(w.iFaceDiscover, w.config.ICEConfig.InterfaceBlackList) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go new file mode 100644 index 000000000..c4e9d1950 --- /dev/null +++ b/client/internal/peer/worker_ice.go @@ -0,0 +1,470 @@ +package peer + +import ( + "context" + "fmt" + "net" + "net/netip" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/pion/ice/v3" + "github.com/pion/randutil" + "github.com/pion/stun/v2" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/internal/stdnet" + "github.com/netbirdio/netbird/route" +) + +const ( + iceKeepAliveDefault = 4 * time.Second + iceDisconnectedTimeoutDefault = 6 * time.Second + // iceRelayAcceptanceMinWaitDefault is the same as in the Pion ICE package + iceRelayAcceptanceMinWaitDefault = 2 * time.Second + + lenUFrag = 16 + lenPwd = 32 + runesAlpha = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +) + +var ( + failedTimeout = 6 * time.Second +) + +type ICEConfig struct { + // StunTurn is a list of STUN and TURN URLs + StunTurn *atomic.Value // []*stun.URI + + // InterfaceBlackList is a list of machine interfaces that should be filtered out by ICE Candidate gathering + // (e.g. if eth0 is in the list, host candidate of this interface won't be used) + InterfaceBlackList []string + DisableIPv6Discovery bool + + UDPMux ice.UDPMux + UDPMuxSrflx ice.UniversalUDPMux + + NATExternalIPs []string +} + +type ICEConnInfo struct { + RemoteConn net.Conn + RosenpassPubKey []byte + RosenpassAddr string + LocalIceCandidateType string + RemoteIceCandidateType string + RemoteIceCandidateEndpoint string + LocalIceCandidateEndpoint string + Relayed bool + RelayedOnLocal bool +} + +type WorkerICECallbacks struct { + OnConnReady func(ConnPriority, ICEConnInfo) + OnStatusChanged func(ConnStatus) +} + +type WorkerICE struct { + ctx context.Context + log *log.Entry + config ConnConfig + signaler *Signaler + iFaceDiscover stdnet.ExternalIFaceDiscover + statusRecorder *Status + hasRelayOnLocally bool + conn WorkerICECallbacks + + selectedPriority ConnPriority + + agent *ice.Agent + muxAgent sync.Mutex + + StunTurn []*stun.URI + + sentExtraSrflx bool + + localUfrag string + localPwd string +} + +func NewWorkerICE(ctx context.Context, log *log.Entry, config ConnConfig, signaler *Signaler, ifaceDiscover stdnet.ExternalIFaceDiscover, statusRecorder *Status, hasRelayOnLocally bool, callBacks WorkerICECallbacks) (*WorkerICE, error) { + w := &WorkerICE{ + ctx: ctx, + log: log, + config: config, + signaler: signaler, + iFaceDiscover: ifaceDiscover, + statusRecorder: statusRecorder, + hasRelayOnLocally: hasRelayOnLocally, + conn: callBacks, + } + + localUfrag, localPwd, err := generateICECredentials() + if err != nil { + return nil, err + } + w.localUfrag = localUfrag + w.localPwd = localPwd + return w, nil +} + +func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { + w.log.Debugf("OnNewOffer for ICE") + w.muxAgent.Lock() + + if w.agent != nil { + w.log.Debugf("agent already exists, skipping the offer") + w.muxAgent.Unlock() + return + } + + var preferredCandidateTypes []ice.CandidateType + if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { + w.selectedPriority = connPriorityICEP2P + preferredCandidateTypes = candidateTypesP2P() + } else { + w.selectedPriority = connPriorityICETurn + preferredCandidateTypes = candidateTypes() + } + + w.log.Debugf("recreate ICE agent") + agentCtx, agentCancel := context.WithCancel(w.ctx) + agent, err := w.reCreateAgent(agentCancel, preferredCandidateTypes) + if err != nil { + w.log.Errorf("failed to recreate ICE Agent: %s", err) + w.muxAgent.Unlock() + return + } + w.agent = agent + w.muxAgent.Unlock() + + w.log.Debugf("gather candidates") + err = w.agent.GatherCandidates() + if err != nil { + w.log.Debugf("failed to gather candidates: %s", err) + return + } + + // will block until connection succeeded + // but it won't release if ICE Agent went into Disconnected or Failed state, + // so we have to cancel it with the provided context once agent detected a broken connection + w.log.Debugf("turn agent dial") + remoteConn, err := w.turnAgentDial(agentCtx, remoteOfferAnswer) + if err != nil { + w.log.Debugf("failed to dial the remote peer: %s", err) + return + } + w.log.Debugf("agent dial succeeded") + + pair, err := w.agent.GetSelectedCandidatePair() + if err != nil { + return + } + + if !isRelayCandidate(pair.Local) { + // dynamically set remote WireGuard port if other side specified a different one from the default one + remoteWgPort := iface.DefaultWgPort + if remoteOfferAnswer.WgListenPort != 0 { + remoteWgPort = remoteOfferAnswer.WgListenPort + } + + // To support old version's with direct mode we attempt to punch an additional role with the remote WireGuard port + go w.punchRemoteWGPort(pair, remoteWgPort) + } + + ci := ICEConnInfo{ + RemoteConn: remoteConn, + RosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, + RosenpassAddr: remoteOfferAnswer.RosenpassAddr, + LocalIceCandidateType: pair.Local.Type().String(), + RemoteIceCandidateType: pair.Remote.Type().String(), + LocalIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Local.Address(), pair.Local.Port()), + RemoteIceCandidateEndpoint: fmt.Sprintf("%s:%d", pair.Remote.Address(), pair.Remote.Port()), + Relayed: isRelayed(pair), + RelayedOnLocal: isRelayCandidate(pair.Local), + } + w.log.Debugf("on ICE conn read to use ready") + go w.conn.OnConnReady(w.selectedPriority, ci) +} + +// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. +func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { + w.muxAgent.Lock() + defer w.muxAgent.Unlock() + w.log.Debugf("OnRemoteCandidate from peer %s -> %s", w.config.Key, candidate.String()) + if w.agent == nil { + w.log.Warnf("ICE Agent is not initialized yet") + return + } + + if candidateViaRoutes(candidate, haRoutes) { + return + } + + err := w.agent.AddRemoteCandidate(candidate) + if err != nil { + w.log.Errorf("error while handling remote candidate") + return + } +} + +func (w *WorkerICE) GetLocalUserCredentials() (frag string, pwd string) { + w.muxAgent.Lock() + defer w.muxAgent.Unlock() + return w.localUfrag, w.localPwd +} + +func (w *WorkerICE) Close() { + w.muxAgent.Lock() + defer w.muxAgent.Unlock() + + if w.agent == nil { + return + } + + err := w.agent.Close() + if err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } +} + +func (w *WorkerICE) reCreateAgent(agentCancel context.CancelFunc, relaySupport []ice.CandidateType) (*ice.Agent, error) { + transportNet, err := w.newStdNet() + if err != nil { + w.log.Errorf("failed to create pion's stdnet: %s", err) + } + + iceKeepAlive := iceKeepAlive() + iceDisconnectedTimeout := iceDisconnectedTimeout() + iceRelayAcceptanceMinWait := iceRelayAcceptanceMinWait() + + agentConfig := &ice.AgentConfig{ + MulticastDNSMode: ice.MulticastDNSModeDisabled, + NetworkTypes: []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}, + Urls: w.config.ICEConfig.StunTurn.Load().([]*stun.URI), + CandidateTypes: relaySupport, + InterfaceFilter: stdnet.InterfaceFilter(w.config.ICEConfig.InterfaceBlackList), + UDPMux: w.config.ICEConfig.UDPMux, + UDPMuxSrflx: w.config.ICEConfig.UDPMuxSrflx, + NAT1To1IPs: w.config.ICEConfig.NATExternalIPs, + Net: transportNet, + FailedTimeout: &failedTimeout, + DisconnectedTimeout: &iceDisconnectedTimeout, + KeepaliveInterval: &iceKeepAlive, + RelayAcceptanceMinWait: &iceRelayAcceptanceMinWait, + LocalUfrag: w.localUfrag, + LocalPwd: w.localPwd, + } + + if w.config.ICEConfig.DisableIPv6Discovery { + agentConfig.NetworkTypes = []ice.NetworkType{ice.NetworkTypeUDP4} + } + + w.sentExtraSrflx = false + agent, err := ice.NewAgent(agentConfig) + if err != nil { + return nil, err + } + + err = agent.OnCandidate(w.onICECandidate) + if err != nil { + return nil, err + } + + err = agent.OnConnectionStateChange(func(state ice.ConnectionState) { + w.log.Debugf("ICE ConnectionState has changed to %s", state.String()) + if state == ice.ConnectionStateFailed || state == ice.ConnectionStateDisconnected { + w.conn.OnStatusChanged(StatusDisconnected) + + w.muxAgent.Lock() + agentCancel() + _ = agent.Close() + w.agent = nil + + w.muxAgent.Unlock() + } + }) + if err != nil { + return nil, err + } + + err = agent.OnSelectedCandidatePairChange(w.onICESelectedCandidatePair) + if err != nil { + return nil, err + } + + err = agent.OnSuccessfulSelectedPairBindingResponse(func(p *ice.CandidatePair) { + err := w.statusRecorder.UpdateLatency(w.config.Key, p.Latency()) + if err != nil { + w.log.Debugf("failed to update latency for peer: %s", err) + return + } + }) + if err != nil { + return nil, fmt.Errorf("failed setting binding response callback: %w", err) + } + + return agent, nil +} + +func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { + // wait local endpoint configuration + time.Sleep(time.Second) + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort)) + if err != nil { + w.log.Warnf("got an error while resolving the udp address, err: %s", err) + return + } + + mux, ok := w.config.ICEConfig.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + if !ok { + w.log.Warn("invalid udp mux conversion") + return + } + _, err = mux.GetSharedConn().WriteTo([]byte{0x6e, 0x62}, addr) + if err != nil { + w.log.Warnf("got an error while sending the punch packet, err: %s", err) + } +} + +// onICECandidate is a callback attached to an ICE Agent to receive new local connection candidates +// and then signals them to the remote peer +func (w *WorkerICE) onICECandidate(candidate ice.Candidate) { + // nil means candidate gathering has been ended + if candidate == nil { + return + } + + // TODO: reported port is incorrect for CandidateTypeHost, makes understanding ICE use via logs confusing as port is ignored + w.log.Debugf("discovered local candidate %s", candidate.String()) + go func() { + err := w.signaler.SignalICECandidate(candidate, w.config.Key) + if err != nil { + w.log.Errorf("failed signaling candidate to the remote peer %s %s", w.config.Key, err) + } + }() + + if !w.shouldSendExtraSrflxCandidate(candidate) { + return + } + + // sends an extra server reflexive candidate to the remote peer with our related port (usually the wireguard port) + // this is useful when network has an existing port forwarding rule for the wireguard port and this peer + extraSrflx, err := extraSrflxCandidate(candidate) + if err != nil { + w.log.Errorf("failed creating extra server reflexive candidate %s", err) + return + } + w.sentExtraSrflx = true + + go func() { + err = w.signaler.SignalICECandidate(extraSrflx, w.config.Key) + if err != nil { + w.log.Errorf("failed signaling the extra server reflexive candidate: %s", err) + } + }() +} + +func (w *WorkerICE) onICESelectedCandidatePair(c1 ice.Candidate, c2 ice.Candidate) { + w.log.Debugf("selected candidate pair [local <-> remote] -> [%s <-> %s], peer %s", c1.String(), c2.String(), + w.config.Key) +} + +func (w *WorkerICE) shouldSendExtraSrflxCandidate(candidate ice.Candidate) bool { + if !w.sentExtraSrflx && candidate.Type() == ice.CandidateTypeServerReflexive && candidate.Port() != candidate.RelatedAddress().Port { + return true + } + return false +} + +func (w *WorkerICE) turnAgentDial(ctx context.Context, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) { + isControlling := w.config.LocalKey > w.config.Key + if isControlling { + return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + } else { + return w.agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + } +} + +func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive, error) { + relatedAdd := candidate.RelatedAddress() + return ice.NewCandidateServerReflexive(&ice.CandidateServerReflexiveConfig{ + Network: candidate.NetworkType().String(), + Address: candidate.Address(), + Port: relatedAdd.Port, + Component: candidate.Component(), + RelAddr: relatedAdd.Address, + RelPort: relatedAdd.Port, + }) +} + +func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { + var routePrefixes []netip.Prefix + for _, routes := range clientRoutes { + if len(routes) > 0 && routes[0] != nil { + routePrefixes = append(routePrefixes, routes[0].Network) + } + } + + addr, err := netip.ParseAddr(candidate.Address()) + if err != nil { + log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err) + return false + } + + for _, prefix := range routePrefixes { + // default route is + if prefix.Bits() == 0 { + continue + } + + if prefix.Contains(addr) { + log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix) + return true + } + } + return false +} + +func candidateTypes() []ice.CandidateType { + if hasICEForceRelayConn() { + return []ice.CandidateType{ice.CandidateTypeRelay} + } + // TODO: remove this once we have refactored userspace proxy into the bind package + if runtime.GOOS == "ios" { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} + } + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive, ice.CandidateTypeRelay} +} + +func candidateTypesP2P() []ice.CandidateType { + return []ice.CandidateType{ice.CandidateTypeHost, ice.CandidateTypeServerReflexive} +} + +func isRelayCandidate(candidate ice.Candidate) bool { + return candidate.Type() == ice.CandidateTypeRelay +} + +func isRelayed(pair *ice.CandidatePair) bool { + if pair.Local.Type() == ice.CandidateTypeRelay || pair.Remote.Type() == ice.CandidateTypeRelay { + return true + } + return false +} + +func generateICECredentials() (string, string, error) { + ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) + if err != nil { + return "", "", err + } + + pwd, err := randutil.GenerateCryptoRandomString(lenPwd, runesAlpha) + if err != nil { + return "", "", err + } + return ufrag, pwd, nil +} diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go new file mode 100644 index 000000000..c02fccebc --- /dev/null +++ b/client/internal/peer/worker_relay.go @@ -0,0 +1,237 @@ +package peer + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + + relayClient "github.com/netbirdio/netbird/relay/client" +) + +var ( + wgHandshakePeriod = 3 * time.Minute + wgHandshakeOvertime = 30 * time.Second +) + +type RelayConnInfo struct { + relayedConn net.Conn + rosenpassPubKey []byte + rosenpassAddr string +} + +type WorkerRelayCallbacks struct { + OnConnReady func(RelayConnInfo) + OnDisconnected func() +} + +type WorkerRelay struct { + log *log.Entry + config ConnConfig + relayManager relayClient.ManagerService + callBacks WorkerRelayCallbacks + + relayedConn net.Conn + relayLock sync.Mutex + ctxWgWatch context.Context + ctxCancelWgWatch context.CancelFunc + ctxLock sync.Mutex + + relaySupportedOnRemotePeer atomic.Bool +} + +func NewWorkerRelay(log *log.Entry, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { + r := &WorkerRelay{ + log: log, + config: config, + relayManager: relayManager, + callBacks: callbacks, + } + return r +} + +func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { + if !w.isRelaySupported(remoteOfferAnswer) { + w.log.Infof("Relay is not supported by remote peer") + w.relaySupportedOnRemotePeer.Store(false) + return + } + w.relaySupportedOnRemotePeer.Store(true) + + // the relayManager will return with error in case if the connection has lost with relay server + currentRelayAddress, err := w.relayManager.RelayInstanceAddress() + if err != nil { + w.log.Errorf("failed to handle new offer: %s", err) + return + } + + srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) + + relayedConn, err := w.relayManager.OpenConn(srv, w.config.Key) + if err != nil { + if errors.Is(err, relayClient.ErrConnAlreadyExists) { + w.log.Debugf("handled offer by reusing existing relay connection") + return + } + w.log.Errorf("failed to open connection via Relay: %s", err) + return + } + w.relayLock.Lock() + w.relayedConn = relayedConn + w.relayLock.Unlock() + + err = w.relayManager.AddCloseListener(srv, w.onRelayMGDisconnected) + if err != nil { + log.Errorf("failed to add close listener: %s", err) + _ = relayedConn.Close() + return + } + + w.log.Debugf("peer conn opened via Relay: %s", srv) + go w.callBacks.OnConnReady(RelayConnInfo{ + relayedConn: relayedConn, + rosenpassPubKey: remoteOfferAnswer.RosenpassPubKey, + rosenpassAddr: remoteOfferAnswer.RosenpassAddr, + }) +} + +func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { + w.log.Debugf("enable WireGuard watcher") + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil { + return + } + + ctx, ctxCancel := context.WithCancel(ctx) + w.ctxWgWatch = ctx + w.ctxCancelWgWatch = ctxCancel + + w.wgStateCheck(ctx, ctxCancel) +} + +func (w *WorkerRelay) DisableWgWatcher() { + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctxCancelWgWatch == nil { + return + } + + w.log.Debugf("disable WireGuard watcher") + + w.ctxCancelWgWatch() +} + +func (w *WorkerRelay) RelayInstanceAddress() (string, error) { + return w.relayManager.RelayInstanceAddress() +} + +func (w *WorkerRelay) IsRelayConnectionSupportedWithPeer() bool { + return w.relaySupportedOnRemotePeer.Load() && w.RelayIsSupportedLocally() +} + +func (w *WorkerRelay) IsController() bool { + return w.config.LocalKey > w.config.Key +} + +func (w *WorkerRelay) RelayIsSupportedLocally() bool { + return w.relayManager.HasRelayAddress() +} + +func (w *WorkerRelay) CloseConn() { + w.relayLock.Lock() + defer w.relayLock.Unlock() + if w.relayedConn == nil { + return + } + + err := w.relayedConn.Close() + if err != nil { + w.log.Warnf("failed to close relay connection: %v", err) + } +} + +// wgStateCheck help to check the state of the WireGuard handshake and relay connection +func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) { + w.log.Debugf("WireGuard watcher started") + lastHandshake, err := w.wgState() + if err != nil { + w.log.Warnf("failed to read wg stats: %v", err) + lastHandshake = time.Time{} + } + + go func(lastHandshake time.Time) { + timer := time.NewTimer(wgHandshakeOvertime) + defer timer.Stop() + defer ctxCancel() + + for { + select { + case <-timer.C: + handshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + timer.Reset(wgHandshakeOvertime) + continue + } + + w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) + + if handshake.Equal(lastHandshake) { + w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) + w.relayLock.Lock() + _ = w.relayedConn.Close() + w.relayLock.Unlock() + w.callBacks.OnDisconnected() + return + } + + resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime)) + lastHandshake = handshake + timer.Reset(resetTime) + case <-ctx.Done(): + w.log.Debugf("WireGuard watcher stopped") + return + } + } + }(lastHandshake) + +} + +func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { + if !w.relayManager.HasRelayAddress() { + return false + } + return answer.RelaySrvAddress != "" +} + +func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress string) string { + if w.IsController() { + return myRelayAddress + } + return remoteRelayAddress +} + +func (w *WorkerRelay) wgState() (time.Time, error) { + wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key) + if err != nil { + return time.Time{}, err + } + return wgState.LastHandshake, nil +} + +func (w *WorkerRelay) onRelayMGDisconnected() { + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctxCancelWgWatch != nil { + w.ctxCancelWgWatch() + } + go w.callBacks.OnDisconnected() +} diff --git a/client/internal/probe.go b/client/internal/probe.go index 743b6b190..23290cf74 100644 --- a/client/internal/probe.go +++ b/client/internal/probe.go @@ -2,6 +2,13 @@ package internal import "context" +type ProbeHolder struct { + MgmProbe *Probe + SignalProbe *Probe + RelayProbe *Probe + WgProbe *Probe +} + // Probe allows to run on-demand callbacks from different code locations. // Pass the probe to a receiving and a sending end. The receiving end starts listening // to requests with Receive and executes a callback when the sending end requests it diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index 4542a37fe..7d98a6060 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -17,7 +17,7 @@ import ( // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { - URI *stun.URI + URI string Err error Addr string } @@ -176,7 +176,7 @@ func ProbeAll( wg.Add(1) go func(res *ProbeResult, stunURI *stun.URI) { defer wg.Done() - res.URI = stunURI + res.URI = stunURI.String() res.Addr, res.Err = fn(ctx, stunURI) }(&results[i], uri) } diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 1566d10dd..eaa232151 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -10,19 +10,18 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) type routerPeerStatus struct { connected bool relayed bool - direct bool latency time.Duration } @@ -44,7 +43,7 @@ type clientNetwork struct { ctx context.Context cancel context.CancelFunc statusRecorder *peer.Status - wgInterface *iface.WGIface + wgInterface iface.IWGIface routes map[route.ID]*route.Route routeUpdate chan routesUpdate peerStateUpdate chan struct{} @@ -54,7 +53,7 @@ type clientNetwork struct { updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface *iface.WGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { +func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ @@ -82,7 +81,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { routePeerStatuses[r.ID] = routerPeerStatus{ connected: peerStatus.ConnStatus == peer.StatusConnected, relayed: peerStatus.Relayed, - direct: peerStatus.Direct, latency: peerStatus.Latency, } } @@ -97,8 +95,8 @@ func (c *clientNetwork) getRouterPeerStatuses() map[route.ID]routerPeerStatus { // * Connected peers: Only routes with connected peers are considered. // * Metric: Routes with lower metrics (better) are prioritized. // * Non-relayed: Routes without relays are preferred. -// * Direct connections: Routes with direct peer connections are favored. // * Latency: Routes with lower latency are prioritized. +// * we compare the current score + 10ms to the chosen score to avoid flapping between routes // * Stability: In case of equal scores, the currently active route (if any) is maintained. // // It returns the ID of the selected optimal route. @@ -137,10 +135,6 @@ func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[route.ID] tempScore++ } - if peerStatus.direct { - tempScore++ - } - if tempScore > chosenScore || (tempScore == chosenScore && chosen == "") { chosen = r.ID chosenScore = tempScore @@ -384,7 +378,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface *iface.WGIface) RouteHandler { +func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { if rt.IsDynamic() { dns := nbdns.NewServiceViaMemory(wgInterface) return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 0ae10e568..583156e4d 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -24,7 +24,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: false, - direct: true, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -43,7 +42,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: true, - direct: true, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -62,7 +60,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: true, - direct: false, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -81,7 +78,6 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: false, relayed: false, - direct: false, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -100,12 +96,10 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: false, - direct: true, }, "route2": { connected: true, relayed: false, - direct: true, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -129,41 +123,10 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: false, - direct: true, }, "route2": { connected: true, relayed: true, - direct: true, - }, - }, - existingRoutes: map[route.ID]*route.Route{ - "route1": { - ID: "route1", - Metric: route.MaxMetric, - Peer: "peer1", - }, - "route2": { - ID: "route2", - Metric: route.MaxMetric, - Peer: "peer2", - }, - }, - currentRoute: "", - expectedRouteID: "route1", - }, - { - name: "multiple connected peers with one direct", - statuses: map[route.ID]routerPeerStatus{ - "route1": { - connected: true, - relayed: false, - direct: true, - }, - "route2": { - connected: true, - relayed: false, - direct: false, }, }, existingRoutes: map[route.ID]*route.Route{ @@ -241,13 +204,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: false, - direct: true, latency: 15 * time.Millisecond, }, "route2": { connected: true, relayed: false, - direct: true, latency: 10 * time.Millisecond, }, }, @@ -272,13 +233,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: false, - direct: true, latency: 200 * time.Millisecond, }, "route2": { connected: true, relayed: false, - direct: true, latency: 10 * time.Millisecond, }, }, @@ -303,13 +262,11 @@ func TestGetBestrouteFromStatuses(t *testing.T) { "route1": { connected: true, relayed: false, - direct: true, latency: 20 * time.Millisecond, }, "route2": { connected: true, relayed: false, - direct: true, latency: 10 * time.Millisecond, }, }, diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index 3296f3ddf..ac94d4a5c 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -13,10 +13,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" ) @@ -48,7 +48,7 @@ type Route struct { currentPeerKey string cancel context.CancelFunc statusRecorder *peer.Status - wgInterface *iface.WGIface + wgInterface iface.IWGIface resolverAddr string } @@ -58,7 +58,7 @@ func NewRoute( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, interval time.Duration, statusRecorder *peer.Status, - wgInterface *iface.WGIface, + wgInterface iface.IWGIface, resolverAddr string, ) *Route { return &Route{ @@ -303,7 +303,7 @@ func (r *Route) addRoutes(domain domain.Domain, prefixes []netip.Prefix) ([]neti var merr *multierror.Error for _, prefix := range prefixes { - if _, err := r.routeRefCounter.Increment(prefix, nil); err != nil { + if _, err := r.routeRefCounter.Increment(prefix, struct{}{}); err != nil { merr = multierror.Append(merr, fmt.Errorf("add dynamic route for IP %s: %w", prefix, err)) continue } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0b10dbe33..d7ddf7ae8 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -14,6 +14,8 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -21,7 +23,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" + relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/version" @@ -49,7 +51,8 @@ type DefaultManager struct { serverRouter serverRouter sysOps *systemops.SysOps statusRecorder *peer.Status - wgInterface *iface.WGIface + relayMgr *relayClient.Manager + wgInterface iface.IWGIface pubKey string notifier *notifier.Notifier routeRefCounter *refcounter.RouteRefCounter @@ -61,8 +64,9 @@ func NewManager( ctx context.Context, pubKey string, dnsRouteInterval time.Duration, - wgInterface *iface.WGIface, + wgInterface iface.IWGIface, statusRecorder *peer.Status, + relayMgr *relayClient.Manager, initialRoutes []*route.Route, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) @@ -74,6 +78,7 @@ func NewManager( stop: cancel, dnsRouteInterval: dnsRouteInterval, clientNetworks: make(map[route.HAUniqueID]*clientNetwork), + relayMgr: relayMgr, routeSelector: routeselector.NewRouteSelector(), sysOps: sysOps, statusRecorder: statusRecorder, @@ -83,10 +88,10 @@ func NewManager( } dm.routeRefCounter = refcounter.New( - func(prefix netip.Prefix, _ any) (any, error) { - return nil, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) + func(prefix netip.Prefix, _ struct{}) (struct{}, error) { + return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) }, - func(prefix netip.Prefix, _ any) error { + func(prefix netip.Prefix, _ struct{}) error { return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) }, ) @@ -98,7 +103,7 @@ func NewManager( }, func(prefix netip.Prefix, peerKey string) error { if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { - if !errors.Is(err, iface.ErrPeerNotFound) && !errors.Is(err, iface.ErrAllowedIPNotFound) { + if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { return err } log.Tracef("Remove allowed IPs %s for %s: %v", prefix, peerKey, err) @@ -124,9 +129,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) log.Warnf("Failed cleaning up routing: %v", err) } - mgmtAddress := m.statusRecorder.GetManagementState().URL - signalAddress := m.statusRecorder.GetSignalState().URL - ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress}) + initialAddresses := []string{m.statusRecorder.GetManagementState().URL, m.statusRecorder.GetSignalState().URL} + if m.relayMgr != nil { + initialAddresses = append(initialAddresses, m.relayMgr.ServerURLs()...) + } + + ips := resolveURLsToIPs(initialAddresses) beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips) if err != nil { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 455c7ac0b..2f26f7a5e 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -416,7 +416,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) _, _, err = routeManager.Init() diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 58a66715c..908279c88 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -5,9 +5,9 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/routeselector" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/util/net" ) diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f1d696ad9..65ea0f708 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -3,7 +3,8 @@ package refcounter import ( "errors" "fmt" - "net/netip" + "runtime" + "strings" "sync" "github.com/hashicorp/go-multierror" @@ -12,118 +13,153 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" ) -// ErrIgnore can be returned by AddFunc to indicate that the counter not be incremented for the given prefix. +const logLevel = log.TraceLevel + +// ErrIgnore can be returned by AddFunc to indicate that the counter should not be incremented for the given key. var ErrIgnore = errors.New("ignore") +// Ref holds the reference count and associated data for a key. type Ref[O any] struct { Count int Out O } -type AddFunc[I, O any] func(prefix netip.Prefix, in I) (out O, err error) -type RemoveFunc[I, O any] func(prefix netip.Prefix, out O) error +// AddFunc is the function type for adding a new key. +// Key is the type of the key (e.g., netip.Prefix). +type AddFunc[Key, I, O any] func(key Key, in I) (out O, err error) -type Counter[I, O any] struct { - // refCountMap keeps track of the reference Ref for prefixes - refCountMap map[netip.Prefix]Ref[O] +// RemoveFunc is the function type for removing a key. +type RemoveFunc[Key, O any] func(key Key, out O) error + +// Counter is a generic reference counter for managing keys and their associated data. +// Key: The type of the key (e.g., netip.Prefix, string). +// +// I: The input type for the AddFunc. It is the input type for additional data needed +// when adding a key, it is passed as the second argument to AddFunc. +// +// O: The output type for the AddFunc and RemoveFunc. This is the output returned by AddFunc. +// It is stored and passed to RemoveFunc when the reference count reaches 0. +// +// The types can be aliased to a specific type using the following syntax: +// +// type RouteRefCounter = Counter[netip.Prefix, any, any] +type Counter[Key comparable, I, O any] struct { + // refCountMap keeps track of the reference Ref for keys + refCountMap map[Key]Ref[O] refCountMu sync.Mutex - // idMap keeps track of the prefixes associated with an ID for removal - idMap map[string][]netip.Prefix + // idMap keeps track of the keys associated with an ID for removal + idMap map[string][]Key idMu sync.Mutex - add AddFunc[I, O] - remove RemoveFunc[I, O] + add AddFunc[Key, I, O] + remove RemoveFunc[Key, O] } -// New creates a new Counter instance -func New[I, O any](add AddFunc[I, O], remove RemoveFunc[I, O]) *Counter[I, O] { - return &Counter[I, O]{ - refCountMap: map[netip.Prefix]Ref[O]{}, - idMap: map[string][]netip.Prefix{}, +// New creates a new Counter instance. +// Usage example: +// +// counter := New[netip.Prefix, string, string]( +// func(key netip.Prefix, in string) (out string, err error) { ... }, +// func(key netip.Prefix, out string) error { ... },` +// ) +func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key, O]) *Counter[Key, I, O] { + return &Counter[Key, I, O]{ + refCountMap: map[Key]Ref[O]{}, + idMap: map[string][]Key{}, add: add, remove: remove, } } -// Increment increments the reference count for the given prefix. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) Increment(prefix netip.Prefix, in I) (Ref[O], error) { +// Get retrieves the current reference count and associated data for a key. +// If the key doesn't exist, it returns a zero value Ref and false. +func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref := rm.refCountMap[prefix] - log.Tracef("Increasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + ref, ok := rm.refCountMap[key] + return ref, ok +} - // Call AddFunc only if it's a new prefix +// Increment increments the reference count for the given key. +// If this is the first reference to the key, the AddFunc is called. +func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + + ref := rm.refCountMap[key] + logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) + + // Call AddFunc only if it's a new key if ref.Count == 0 { - log.Tracef("Adding for prefix %s with [%v]", prefix, ref.Out) - out, err := rm.add(prefix, in) + logCallerF("Calling add for key %v", key) + out, err := rm.add(key, in) if errors.Is(err, ErrIgnore) { return ref, nil } if err != nil { - return ref, fmt.Errorf("failed to add for prefix %s: %w", prefix, err) + return ref, fmt.Errorf("failed to add for key %v: %w", key, err) } ref.Out = out } ref.Count++ - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref return ref, nil } -// IncrementWithID increments the reference count for the given prefix and groups it under the given ID. -// If this is the first reference to the prefix, the AddFunc is called. -func (rm *Counter[I, O]) IncrementWithID(id string, prefix netip.Prefix, in I) (Ref[O], error) { +// IncrementWithID increments the reference count for the given key and groups it under the given ID. +// If this is the first reference to the key, the AddFunc is called. +func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { rm.idMu.Lock() defer rm.idMu.Unlock() - ref, err := rm.Increment(prefix, in) + ref, err := rm.Increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } - rm.idMap[id] = append(rm.idMap[id], prefix) + rm.idMap[id] = append(rm.idMap[id], key) return ref, nil } -// Decrement decrements the reference count for the given prefix. +// Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) Decrement(prefix netip.Prefix) (Ref[O], error) { +func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() - ref, ok := rm.refCountMap[prefix] + ref, ok := rm.refCountMap[key] if !ok { - log.Tracef("No reference found for prefix %s", prefix) + logCallerF("No reference found for key %v", key) return ref, nil } - log.Tracef("Decreasing ref count %d for prefix %s with [%v]", ref.Count, prefix, ref.Out) + logCallerF("Decreasing ref count [%d -> %d] for key %v with Out [%v]", ref.Count, ref.Count-1, key, ref.Out) if ref.Count == 1 { - log.Tracef("Removing for prefix %s with [%v]", prefix, ref.Out) - if err := rm.remove(prefix, ref.Out); err != nil { - return ref, fmt.Errorf("remove for prefix %s: %w", prefix, err) + logCallerF("Calling remove for key %v", key) + if err := rm.remove(key, ref.Out); err != nil { + return ref, fmt.Errorf("remove for key %v: %w", key, err) } - delete(rm.refCountMap, prefix) + delete(rm.refCountMap, key) } else { ref.Count-- - rm.refCountMap[prefix] = ref + rm.refCountMap[key] = ref } return ref, nil } -// DecrementWithID decrements the reference count for all prefixes associated with the given ID. +// DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. -func (rm *Counter[I, O]) DecrementWithID(id string) error { +func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for _, prefix := range rm.idMap[id] { - if _, err := rm.Decrement(prefix); err != nil { + for _, key := range rm.idMap[id] { + if _, err := rm.Decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -132,24 +168,77 @@ func (rm *Counter[I, O]) DecrementWithID(id string) error { return nberrors.FormatErrorOrNil(merr) } -// Flush removes all references and calls RemoveFunc for each prefix. -func (rm *Counter[I, O]) Flush() error { +// Flush removes all references and calls RemoveFunc for each key. +func (rm *Counter[Key, I, O]) Flush() error { rm.refCountMu.Lock() defer rm.refCountMu.Unlock() rm.idMu.Lock() defer rm.idMu.Unlock() var merr *multierror.Error - for prefix := range rm.refCountMap { - log.Tracef("Removing for prefix %s", prefix) - ref := rm.refCountMap[prefix] - if err := rm.remove(prefix, ref.Out); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove for prefix %s: %w", prefix, err)) + for key := range rm.refCountMap { + logCallerF("Calling remove for key %v", key) + ref := rm.refCountMap[key] + if err := rm.remove(key, ref.Out); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove for key %v: %w", key, err)) } } - rm.refCountMap = map[netip.Prefix]Ref[O]{} - rm.idMap = map[string][]netip.Prefix{} + clear(rm.refCountMap) + clear(rm.idMap) return nberrors.FormatErrorOrNil(merr) } + +// Clear removes all references without calling RemoveFunc. +func (rm *Counter[Key, I, O]) Clear() { + rm.refCountMu.Lock() + defer rm.refCountMu.Unlock() + rm.idMu.Lock() + defer rm.idMu.Unlock() + + clear(rm.refCountMap) + clear(rm.idMap) +} + +func getCallerInfo(depth int, maxDepth int) (string, bool) { + if depth >= maxDepth { + return "", false + } + + pc, _, _, ok := runtime.Caller(depth) + if !ok { + return "", false + } + + if details := runtime.FuncForPC(pc); details != nil { + name := details.Name() + + lastDotIndex := strings.LastIndex(name, "/") + if lastDotIndex != -1 { + name = name[lastDotIndex+1:] + } + + if strings.HasPrefix(name, "refcounter.") { + // +2 to account for recursion + return getCallerInfo(depth+2, maxDepth) + } + + return name, true + } + + return "", false +} + +// logCaller logs a message with the package name and method of the function that called the current function. +func logCallerF(format string, args ...interface{}) { + if log.GetLevel() < logLevel { + return + } + + if callerName, ok := getCallerInfo(3, 18); ok { + format = fmt.Sprintf("[%s] %s", callerName, format) + } + + log.StandardLogger().Logf(logLevel, format, args...) +} diff --git a/client/internal/routemanager/refcounter/types.go b/client/internal/routemanager/refcounter/types.go index 6753b64ef..aadac3e25 100644 --- a/client/internal/routemanager/refcounter/types.go +++ b/client/internal/routemanager/refcounter/types.go @@ -1,7 +1,9 @@ package refcounter +import "net/netip" + // RouteRefCounter is a Counter for Route, it doesn't take any input on Increment and doesn't use any output on Decrement -type RouteRefCounter = Counter[any, any] +type RouteRefCounter = Counter[netip.Prefix, struct{}, struct{}] // AllowedIPsRefCounter is a Counter for AllowedIPs, it takes a peer key on Increment and passes it back to Decrement -type AllowedIPsRefCounter = Counter[string, string] +type AllowedIPsRefCounter = Counter[netip.Prefix, string, string] diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index b4065bca6..c75a0a7f2 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -7,10 +7,10 @@ import ( "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" - "github.com/netbirdio/netbird/iface" ) -func newServerRouter(context.Context, *iface.WGIface, firewall.Manager, *peer.Status) (serverRouter, error) { +func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index 8470934c2..ef38d5707 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -11,9 +11,9 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" - "github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/route" ) @@ -22,11 +22,11 @@ type defaultServerRouter struct { ctx context.Context routes map[route.ID]*route.Route firewall firewall.Manager - wgInterface *iface.WGIface + wgInterface iface.IWGIface statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface *iface.WGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { +func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { return &defaultServerRouter{ ctx: ctx, routes: make(map[route.ID]*route.Route), @@ -94,7 +94,7 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { return fmt.Errorf("remove routing rules: %w", err) } @@ -123,7 +123,7 @@ func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { return fmt.Errorf("parse prefix: %w", err) } - err = m.firewall.InsertRoutingRules(routerPair) + err = m.firewall.AddNatRule(routerPair) if err != nil { return fmt.Errorf("insert routing rules: %w", err) } @@ -157,7 +157,7 @@ func (m *defaultServerRouter) cleanUp() { continue } - err = m.firewall.RemoveRoutingRules(routerPair) + err = m.firewall.RemoveNatRule(routerPair) if err != nil { log.Errorf("Failed to remove cleanup route: %v", err) } @@ -173,15 +173,15 @@ func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { // TODO: add ipv6 source := getDefaultPrefix(route.Network) - destination := route.Network.Masked().String() + destination := route.Network.Masked() if route.IsDynamic() { - // TODO: add ipv6 - destination = "0.0.0.0/0" + // TODO: add ipv6 additionally + destination = getDefaultPrefix(destination) } return firewall.RouterPair{ - ID: string(route.ID), - Source: source.String(), + ID: route.ID, + Source: source, Destination: destination, Masquerade: route.Masquerade, }, nil diff --git a/client/internal/routemanager/static/route.go b/client/internal/routemanager/static/route.go index 88cca522a..98c34dbee 100644 --- a/client/internal/routemanager/static/route.go +++ b/client/internal/routemanager/static/route.go @@ -30,7 +30,7 @@ func (r *Route) String() string { } func (r *Route) AddRoute(context.Context) error { - _, err := r.routeRefCounter.Increment(r.route.Network, nil) + _, err := r.routeRefCounter.Increment(r.route.Network, struct{}{}) return err } diff --git a/client/internal/routemanager/sysctl/sysctl_linux.go b/client/internal/routemanager/sysctl/sysctl_linux.go index 43394a823..bb620ee68 100644 --- a/client/internal/routemanager/sysctl/sysctl_linux.go +++ b/client/internal/routemanager/sysctl/sysctl_linux.go @@ -13,7 +13,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -23,7 +23,7 @@ const ( ) // Setup configures sysctl settings for RP filtering and source validation. -func Setup(wgIface *iface.WGIface) (map[string]int, error) { +func Setup(wgIface iface.IWGIface) (map[string]int, error) { keys := map[string]int{} var result *multierror.Error diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index cddd7e7e2..d1cb83bfb 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -5,9 +5,9 @@ import ( "net/netip" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" - "github.com/netbirdio/netbird/iface" ) type Nexthop struct { @@ -15,11 +15,11 @@ type Nexthop struct { Intf *net.Interface } -type ExclusionCounter = refcounter.Counter[any, Nexthop] +type ExclusionCounter = refcounter.Counter[netip.Prefix, struct{}, Nexthop] type SysOps struct { refCounter *ExclusionCounter - wgInterface *iface.WGIface + wgInterface iface.IWGIface // prefixes is tracking all the current added prefixes im memory // (this is used in iOS as all route updates require a full table update) //nolint @@ -30,7 +30,7 @@ type SysOps struct { notifier *notifier.Notifier } -func NewSysOps(wgInterface *iface.WGIface, notifier *notifier.Notifier) *SysOps { +func NewSysOps(wgInterface iface.IWGIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 671545b86..9258f4a4e 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -16,10 +16,10 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" - "github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net" ) @@ -41,7 +41,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP) (nbnet.AddHookFunc, nbn } refCounter := refcounter.New( - func(prefix netip.Prefix, _ any) (Nexthop, error) { + func(prefix netip.Prefix, _ struct{}) (Nexthop, error) { initialNexthop := initialNextHopV4 if prefix.Addr().Is6() { initialNexthop = initialNextHopV6 @@ -122,7 +122,7 @@ func (r *SysOps) addRouteForCurrentDefaultGateway(prefix netip.Prefix) error { // addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface. // If the next hop or interface is pointing to the VPN interface, it will return the initial values. -func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIface, initialNextHop Nexthop) (Nexthop, error) { +func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf iface.IWGIface, initialNextHop Nexthop) (Nexthop, error) { addr := prefix.Addr() switch { case addr.IsLoopback(), @@ -317,7 +317,7 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("convert ip to prefix: %w", err) } - if _, err := r.refCounter.IncrementWithID(string(connID), prefix, nil); err != nil { + if _, err := r.refCounter.IncrementWithID(string(connID), prefix, struct{}{}); err != nil { return fmt.Errorf("adding route reference: %v", err) } diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 94965c119..238225807 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -19,7 +19,7 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/client/iface" ) type dialer interface { diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 0d3630cb8..3f756788e 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -3,6 +3,8 @@ package systemops import ( + "context" + "encoding/binary" "fmt" "net" "net/netip" @@ -11,15 +13,43 @@ import ( "strconv" "strings" "sync" + "syscall" "time" + "unsafe" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" + "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/firewall/uspfilter" nbnet "github.com/netbirdio/netbird/util/net" ) +type RouteUpdateType int + +// RouteUpdate represents a change in the routing table. +// The interface field contains the index only. +type RouteUpdate struct { + Type RouteUpdateType + Destination netip.Prefix + NextHop netip.Addr + Interface *net.Interface +} + +// RouteMonitor provides a way to monitor changes in the routing table. +type RouteMonitor struct { + updates chan RouteUpdate + handle windows.Handle + done chan struct{} +} + +// Route represents a single routing table entry. +type Route struct { + Destination netip.Prefix + Nexthop netip.Addr + Interface *net.Interface +} + type MSFT_NetRoute struct { DestinationPrefix string NextHop string @@ -28,33 +58,77 @@ type MSFT_NetRoute struct { AddressFamily uint16 } -type Route struct { - Destination netip.Prefix - Nexthop netip.Addr - Interface *net.Interface +// MIB_IPFORWARD_ROW2 is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-mib_ipforward_row2 +type MIB_IPFORWARD_ROW2 struct { + InterfaceLuid uint64 + InterfaceIndex uint32 + DestinationPrefix IP_ADDRESS_PREFIX + NextHop SOCKADDR_INET_NEXTHOP + SitePrefixLength uint8 + ValidLifetime uint32 + PreferredLifetime uint32 + Metric uint32 + Protocol uint32 + Loopback uint8 + AutoconfigureAddress uint8 + Publish uint8 + Immortal uint8 + Age uint32 + Origin uint32 } -type MSFT_NetNeighbor struct { - IPAddress string - LinkLayerAddress string - State uint8 - AddressFamily uint16 - InterfaceIndex uint32 - InterfaceAlias string +// IP_ADDRESS_PREFIX is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ns-netioapi-ip_address_prefix +type IP_ADDRESS_PREFIX struct { + Prefix SOCKADDR_INET + PrefixLength uint8 } -type Neighbor struct { - IPAddress netip.Addr - LinkLayerAddress string - State uint8 - AddressFamily uint16 - InterfaceIndex uint32 - InterfaceAlias string +// SOCKADDR_INET is defined in https://learn.microsoft.com/en-us/windows/win32/api/ws2ipdef/ns-ws2ipdef-sockaddr_inet +// It represents the union of IPv4 and IPv6 socket addresses +type SOCKADDR_INET struct { + sin6_family int16 + // nolint:unused + sin6_port uint16 + // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id + data [24]byte } -var prefixList []netip.Prefix -var lastUpdate time.Time -var mux = sync.Mutex{} +// SOCKADDR_INET_NEXTHOP is the same as SOCKADDR_INET but offset by 2 bytes +type SOCKADDR_INET_NEXTHOP struct { + // nolint:unused + pad [2]byte + sin6_family int16 + // nolint:unused + sin6_port uint16 + // 4 bytes ipv4 or 4 bytes flowinfo + 16 bytes ipv6 + 4 bytes scope_id + data [24]byte +} + +// MIB_NOTIFICATION_TYPE is defined in https://learn.microsoft.com/en-us/windows/win32/api/netioapi/ne-netioapi-mib_notification_type +type MIB_NOTIFICATION_TYPE int32 + +var ( + modiphlpapi = windows.NewLazyDLL("iphlpapi.dll") + procNotifyRouteChange2 = modiphlpapi.NewProc("NotifyRouteChange2") + procCancelMibChangeNotify2 = modiphlpapi.NewProc("CancelMibChangeNotify2") + + prefixList []netip.Prefix + lastUpdate time.Time + mux sync.Mutex +) + +const ( + MibParemeterModification MIB_NOTIFICATION_TYPE = iota + MibAddInstance + MibDeleteInstance + MibInitialNotification +) + +const ( + RouteModified RouteUpdateType = iota + RouteAdded + RouteDeleted +) func (r *SysOps) SetupRouting(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { return r.setupRefCounter(initAddresses) @@ -94,6 +168,155 @@ func (r *SysOps) removeFromRouteTable(prefix netip.Prefix, nexthop Nexthop) erro return nil } +// NewRouteMonitor creates and starts a new RouteMonitor. +// It returns a pointer to the RouteMonitor and an error if the monitor couldn't be started. +func NewRouteMonitor(ctx context.Context) (*RouteMonitor, error) { + rm := &RouteMonitor{ + updates: make(chan RouteUpdate, 5), + done: make(chan struct{}), + } + + if err := rm.start(ctx); err != nil { + return nil, err + } + + return rm, nil +} + +func (rm *RouteMonitor) start(ctx context.Context) error { + if ctx.Err() != nil { + return ctx.Err() + } + + callbackPtr := windows.NewCallback(func(callerContext uintptr, row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) uintptr { + if ctx.Err() != nil { + return 0 + } + + update, err := rm.parseUpdate(row, notificationType) + if err != nil { + log.Errorf("Failed to parse route update: %v", err) + return 0 + } + + select { + case <-rm.done: + return 0 + case rm.updates <- update: + default: + log.Warn("Route update channel is full, dropping update") + } + return 0 + }) + + var handle windows.Handle + if err := notifyRouteChange2(windows.AF_UNSPEC, callbackPtr, 0, false, &handle); err != nil { + return fmt.Errorf("NotifyRouteChange2 failed: %w", err) + } + + rm.handle = handle + + return nil +} + +func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MIB_NOTIFICATION_TYPE) (RouteUpdate, error) { + // destination prefix, next hop, interface index, interface luid are guaranteed to be there + // GetIpForwardEntry2 is not needed + + var update RouteUpdate + + idx := int(row.InterfaceIndex) + if idx != 0 { + intf, err := net.InterfaceByIndex(idx) + if err != nil { + return update, fmt.Errorf("get interface name: %w", err) + } + + update.Interface = intf + } + + log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) + dest := parseIPPrefix(row.DestinationPrefix, idx) + if !dest.Addr().IsValid() { + return RouteUpdate{}, fmt.Errorf("invalid destination: %v", row) + } + + nexthop := parseIPNexthop(row.NextHop, idx) + if !nexthop.IsValid() { + return RouteUpdate{}, fmt.Errorf("invalid next hop %v", row) + } + + updateType := RouteModified + switch notificationType { + case MibParemeterModification: + updateType = RouteModified + case MibAddInstance: + updateType = RouteAdded + case MibDeleteInstance: + updateType = RouteDeleted + } + + update.Type = updateType + update.Destination = dest + update.NextHop = nexthop + + return update, nil +} + +// Stop stops the RouteMonitor. +func (rm *RouteMonitor) Stop() error { + if rm.handle != 0 { + if err := cancelMibChangeNotify2(rm.handle); err != nil { + return fmt.Errorf("CancelMibChangeNotify2 failed: %w", err) + } + rm.handle = 0 + } + close(rm.done) + close(rm.updates) + return nil +} + +// RouteUpdates returns a channel that receives RouteUpdate messages. +func (rm *RouteMonitor) RouteUpdates() <-chan RouteUpdate { + return rm.updates +} + +func notifyRouteChange2(family uint32, callback uintptr, callerContext uintptr, initialNotification bool, handle *windows.Handle) error { + var initNotif uint32 + if initialNotification { + initNotif = 1 + } + + r1, _, e1 := syscall.SyscallN( + procNotifyRouteChange2.Addr(), + uintptr(family), + callback, + callerContext, + uintptr(initNotif), + uintptr(unsafe.Pointer(handle)), + ) + if r1 != 0 { + if e1 != 0 { + return e1 + } + return syscall.EINVAL + } + return nil +} + +func cancelMibChangeNotify2(handle windows.Handle) error { + r1, _, e1 := syscall.SyscallN(procCancelMibChangeNotify2.Addr(), uintptr(handle)) + if r1 != 0 { + if e1 != 0 { + return e1 + } + return syscall.EINVAL + } + return nil +} + +// GetRoutesFromTable returns the current routing table from with prefixes only. +// It ccaches the result for 2 seconds to avoid blocking the caller. func GetRoutesFromTable() ([]netip.Prefix, error) { mux.Lock() defer mux.Unlock() @@ -117,6 +340,7 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { return prefixList, nil } +// GetRoutes retrieves the current routing table using WMI. func GetRoutes() ([]Route, error) { var entries []MSFT_NetRoute @@ -146,8 +370,8 @@ func GetRoutes() ([]Route, error) { Name: entry.InterfaceAlias, } - if nexthop.Is6() && (nexthop.IsLinkLocalUnicast() || nexthop.IsLinkLocalMulticast()) { - nexthop = nexthop.WithZone(strconv.Itoa(int(entry.InterfaceIndex))) + if nexthop.Is6() { + nexthop = addZone(nexthop, int(entry.InterfaceIndex)) } } @@ -161,33 +385,6 @@ func GetRoutes() ([]Route, error) { return routes, nil } -func GetNeighbors() ([]Neighbor, error) { - var entries []MSFT_NetNeighbor - query := `SELECT IPAddress, LinkLayerAddress, State, AddressFamily, InterfaceIndex, InterfaceAlias FROM MSFT_NetNeighbor` - if err := wmi.QueryNamespace(query, &entries, `ROOT\StandardCimv2`); err != nil { - return nil, fmt.Errorf("failed to query MSFT_NetNeighbor: %w", err) - } - - var neighbors []Neighbor - for _, entry := range entries { - addr, err := netip.ParseAddr(entry.IPAddress) - if err != nil { - log.Warnf("Unable to parse neighbor IP address %s: %v", entry.IPAddress, err) - continue - } - neighbors = append(neighbors, Neighbor{ - IPAddress: addr, - LinkLayerAddress: entry.LinkLayerAddress, - State: entry.State, - AddressFamily: entry.AddressFamily, - InterfaceIndex: entry.InterfaceIndex, - InterfaceAlias: entry.InterfaceAlias, - }) - } - - return neighbors, nil -} - func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { args := []string{"add", prefix.String()} @@ -220,3 +417,54 @@ func addRouteCmd(prefix netip.Prefix, nexthop Nexthop) error { func isCacheDisabled() bool { return os.Getenv("NB_DISABLE_ROUTE_CACHE") == "true" } + +func parseIPPrefix(prefix IP_ADDRESS_PREFIX, idx int) netip.Prefix { + ip := parseIP(prefix.Prefix, idx) + return netip.PrefixFrom(ip, int(prefix.PrefixLength)) +} + +func parseIP(addr SOCKADDR_INET, idx int) netip.Addr { + return parseIPGeneric(addr.sin6_family, addr.data, idx) +} + +func parseIPNexthop(addr SOCKADDR_INET_NEXTHOP, idx int) netip.Addr { + return parseIPGeneric(addr.sin6_family, addr.data, idx) +} + +func parseIPGeneric(family int16, data [24]byte, interfaceIndex int) netip.Addr { + switch family { + case windows.AF_INET: + ipv4 := binary.BigEndian.Uint32(data[:4]) + return netip.AddrFrom4([4]byte{ + byte(ipv4 >> 24), + byte(ipv4 >> 16), + byte(ipv4 >> 8), + byte(ipv4), + }) + + case windows.AF_INET6: + // The IPv6 address is stored after the 4-byte flowinfo field + var ipv6 [16]byte + copy(ipv6[:], data[4:20]) + ip := netip.AddrFrom16(ipv6) + + // Check if there's a non-zero scope_id + scopeID := binary.BigEndian.Uint32(data[20:24]) + if scopeID != 0 { + ip = ip.WithZone(strconv.FormatUint(uint64(scopeID), 10)) + } else if interfaceIndex != 0 { + ip = addZone(ip, interfaceIndex) + } + + return ip + } + + return netip.IPv4Unspecified() +} + +func addZone(ip netip.Addr, interfaceIndex int) netip.Addr { + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + ip = ip.WithZone(strconv.Itoa(interfaceIndex)) + } + return ip +} diff --git a/client/internal/wgproxy/portlookup.go b/client/internal/wgproxy/ebpf/portlookup.go similarity index 96% rename from client/internal/wgproxy/portlookup.go rename to client/internal/wgproxy/ebpf/portlookup.go index 6f3d33487..0e2c20c99 100644 --- a/client/internal/wgproxy/portlookup.go +++ b/client/internal/wgproxy/ebpf/portlookup.go @@ -1,4 +1,4 @@ -package wgproxy +package ebpf import ( "fmt" diff --git a/client/internal/wgproxy/portlookup_test.go b/client/internal/wgproxy/ebpf/portlookup_test.go similarity index 97% rename from client/internal/wgproxy/portlookup_test.go rename to client/internal/wgproxy/ebpf/portlookup_test.go index 6a386f330..92f4b8eee 100644 --- a/client/internal/wgproxy/portlookup_test.go +++ b/client/internal/wgproxy/ebpf/portlookup_test.go @@ -1,4 +1,4 @@ -package wgproxy +package ebpf import ( "fmt" diff --git a/client/internal/wgproxy/proxy_ebpf.go b/client/internal/wgproxy/ebpf/proxy.go similarity index 64% rename from client/internal/wgproxy/proxy_ebpf.go rename to client/internal/wgproxy/ebpf/proxy.go index bbd00d6e2..27ede3ef1 100644 --- a/client/internal/wgproxy/proxy_ebpf.go +++ b/client/internal/wgproxy/ebpf/proxy.go @@ -1,6 +1,6 @@ //go:build linux && !android -package wgproxy +package ebpf import ( "context" @@ -13,47 +13,49 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/hashicorp/go-multierror" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" nbnet "github.com/netbirdio/netbird/util/net" ) +const ( + loopbackAddr = "127.0.0.1" +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { - ebpfManager ebpfMgr.Manager - - ctx context.Context - cancel context.CancelFunc - - lastUsedPort uint16 localWGListenPort int + ebpfManager ebpfMgr.Manager turnConnStore map[uint16]net.Conn turnConnMutex sync.Mutex - rawConn net.PacketConn - conn transport.UDPConn + lastUsedPort uint16 + rawConn net.PacketConn + conn transport.UDPConn + + ctx context.Context + ctxCancel context.CancelFunc } // NewWGEBPFProxy create new WGEBPFProxy instance -func NewWGEBPFProxy(ctx context.Context, wgPort int) *WGEBPFProxy { +func NewWGEBPFProxy(wgPort int) *WGEBPFProxy { log.Debugf("instantiate ebpf proxy") wgProxy := &WGEBPFProxy{ localWGListenPort: wgPort, ebpfManager: ebpf.GetEbpfManagerInstance(), - lastUsedPort: 0, turnConnStore: make(map[uint16]net.Conn), } - wgProxy.ctx, wgProxy.cancel = context.WithCancel(ctx) - return wgProxy } -// listen load ebpf program and listen the proxy -func (p *WGEBPFProxy) listen() error { +// Listen load ebpf program and listen the proxy +func (p *WGEBPFProxy) Listen() error { pl := portLookup{} wgPorxyPort, err := pl.searchFreePort() if err != nil { @@ -72,13 +74,14 @@ func (p *WGEBPFProxy) listen() error { addr := net.UDPAddr{ Port: wgPorxyPort, - IP: net.ParseIP("127.0.0.1"), + IP: net.ParseIP(loopbackAddr), } + p.ctx, p.ctxCancel = context.WithCancel(context.Background()) + conn, err := nbnet.ListenUDP("udp", &addr) if err != nil { - cErr := p.Free() - if cErr != nil { + if cErr := p.Free(); cErr != nil { log.Errorf("Failed to close the wgproxy: %s", cErr) } return err @@ -91,108 +94,114 @@ func (p *WGEBPFProxy) listen() error { } // AddTurnConn add new turn connection for the proxy -func (p *WGEBPFProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { +func (p *WGEBPFProxy) AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) { wgEndpointPort, err := p.storeTurnConn(turnConn) if err != nil { return nil, err } - go p.proxyToLocal(wgEndpointPort, turnConn) + go p.proxyToLocal(ctx, wgEndpointPort, turnConn) log.Infof("turn conn added to wg proxy store: %s, endpoint port: :%d", turnConn.RemoteAddr(), wgEndpointPort) wgEndpoint := &net.UDPAddr{ - IP: net.ParseIP("127.0.0.1"), + IP: net.ParseIP(loopbackAddr), Port: int(wgEndpointPort), } return wgEndpoint, nil } -// CloseConn doing nothing because this type of proxy implementation does not store the connection -func (p *WGEBPFProxy) CloseConn() error { - return nil -} - -// Free resources +// Free resources except the remoteConns will be keep open. func (p *WGEBPFProxy) Free() error { log.Debugf("free up ebpf wg proxy") - var err1, err2, err3 error - if p.conn != nil { - err1 = p.conn.Close() + if p.ctx != nil && p.ctx.Err() != nil { + //nolint + return nil } - err2 = p.ebpfManager.FreeWGProxy() - if p.rawConn != nil { - err3 = p.rawConn.Close() + p.ctxCancel() + + var result *multierror.Error + if p.conn != nil { // p.conn will be nil if we have failed to listen + if err := p.conn.Close(); err != nil { + result = multierror.Append(result, err) + } } - if err1 != nil { - return err1 + if err := p.ebpfManager.FreeWGProxy(); err != nil { + result = multierror.Append(result, err) } - if err2 != nil { - return err2 + if err := p.rawConn.Close(); err != nil { + result = multierror.Append(result, err) } - - return err3 + return nberrors.FormatErrorOrNil(result) } -func (p *WGEBPFProxy) proxyToLocal(endpointPort uint16, remoteConn net.Conn) { +func (p *WGEBPFProxy) proxyToLocal(ctx context.Context, endpointPort uint16, remoteConn net.Conn) { + defer p.removeTurnConn(endpointPort) + + var ( + err error + n int + ) buf := make([]byte, 1500) - var err error - defer func() { - p.removeTurnConn(endpointPort) - }() - for { - select { - case <-p.ctx.Done(): - return - default: - var n int - n, err = remoteConn.Read(buf) - if err != nil { - if err != io.EOF { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) - } + for ctx.Err() == nil { + n, err = remoteConn.Read(buf) + if err != nil { + if ctx.Err() != nil { return } - err = p.sendPkg(buf[:n], endpointPort) - if err != nil { - log.Errorf("failed to write out turn pkg to local conn: %v", err) + if err != io.EOF { + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", endpointPort, err) } + return + } + + if err := p.sendPkg(buf[:n], endpointPort); err != nil { + if ctx.Err() != nil || p.ctx.Err() != nil { + return + } + log.Errorf("failed to write out turn pkg to local conn: %v", err) } } } // proxyToRemote read messages from local WireGuard interface and forward it to remote conn +// From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { buf := make([]byte, 1500) - for { - select { - case <-p.ctx.Done(): - return - default: - n, addr, err := p.conn.ReadFromUDP(buf) - if err != nil { - log.Errorf("failed to read UDP pkg from WG: %s", err) + for p.ctx.Err() == nil { + if err := p.readAndForwardPacket(buf); err != nil { + if p.ctx.Err() != nil { return } - - p.turnConnMutex.Lock() - conn, ok := p.turnConnStore[uint16(addr.Port)] - p.turnConnMutex.Unlock() - if !ok { - log.Infof("turn conn not found by port: %d", addr.Port) - continue - } - - _, err = conn.Write(buf[:n]) - if err != nil { - log.Debugf("failed to forward local wg pkg (%d) to remote turn conn: %s", addr.Port, err) - } + log.Errorf("failed to proxy packet to remote conn: %s", err) } } } +func (p *WGEBPFProxy) readAndForwardPacket(buf []byte) error { + n, addr, err := p.conn.ReadFromUDP(buf) + if err != nil { + return fmt.Errorf("failed to read UDP packet from WG: %w", err) + } + + p.turnConnMutex.Lock() + conn, ok := p.turnConnStore[uint16(addr.Port)] + p.turnConnMutex.Unlock() + if !ok { + if p.ctx.Err() == nil { + log.Debugf("turn conn not found by port because conn already has been closed: %d", addr.Port) + } + return nil + } + + if _, err := conn.Write(buf[:n]); err != nil { + return fmt.Errorf("failed to forward local WG packet (%d) to remote turn conn: %w", addr.Port, err) + } + return nil +} + func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { p.turnConnMutex.Lock() defer p.turnConnMutex.Unlock() @@ -206,11 +215,14 @@ func (p *WGEBPFProxy) storeTurnConn(turnConn net.Conn) (uint16, error) { } func (p *WGEBPFProxy) removeTurnConn(turnConnID uint16) { - log.Tracef("remove turn conn from store by port: %d", turnConnID) p.turnConnMutex.Lock() defer p.turnConnMutex.Unlock() - delete(p.turnConnStore, turnConnID) + _, ok := p.turnConnStore[turnConnID] + if ok { + log.Debugf("remove turn conn from store by port: %d", turnConnID) + } + delete(p.turnConnStore, turnConnID) } func (p *WGEBPFProxy) nextFreePort() (uint16, error) { diff --git a/client/internal/wgproxy/proxy_ebpf_test.go b/client/internal/wgproxy/ebpf/proxy_test.go similarity index 86% rename from client/internal/wgproxy/proxy_ebpf_test.go rename to client/internal/wgproxy/ebpf/proxy_test.go index 821e64218..b15bc686c 100644 --- a/client/internal/wgproxy/proxy_ebpf_test.go +++ b/client/internal/wgproxy/ebpf/proxy_test.go @@ -1,14 +1,13 @@ //go:build linux && !android -package wgproxy +package ebpf import ( - "context" "testing" ) func TestWGEBPFProxy_connStore(t *testing.T) { - wgProxy := NewWGEBPFProxy(context.Background(), 1) + wgProxy := NewWGEBPFProxy(1) p, _ := wgProxy.storeTurnConn(nil) if p != 1 { @@ -28,7 +27,7 @@ func TestWGEBPFProxy_connStore(t *testing.T) { } func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { - wgProxy := NewWGEBPFProxy(context.Background(), 1) + wgProxy := NewWGEBPFProxy(1) _, _ = wgProxy.storeTurnConn(nil) wgProxy.lastUsedPort = 65535 @@ -44,7 +43,7 @@ func TestWGEBPFProxy_portCalculation_overflow(t *testing.T) { } func TestWGEBPFProxy_portCalculation_maxConn(t *testing.T) { - wgProxy := NewWGEBPFProxy(context.Background(), 1) + wgProxy := NewWGEBPFProxy(1) for i := 0; i < 65535; i++ { _, _ = wgProxy.storeTurnConn(nil) diff --git a/client/internal/wgproxy/ebpf/wrapper.go b/client/internal/wgproxy/ebpf/wrapper.go new file mode 100644 index 000000000..c5639f840 --- /dev/null +++ b/client/internal/wgproxy/ebpf/wrapper.go @@ -0,0 +1,44 @@ +//go:build linux && !android + +package ebpf + +import ( + "context" + "fmt" + "net" +) + +// ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call +type ProxyWrapper struct { + WgeBPFProxy *WGEBPFProxy + + remoteConn net.Conn + cancel context.CancelFunc // with thic cancel function, we stop remoteToLocal thread +} + +func (e *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { + ctxConn, cancel := context.WithCancel(ctx) + addr, err := e.WgeBPFProxy.AddTurnConn(ctxConn, remoteConn) + + if err != nil { + cancel() + return nil, fmt.Errorf("add turn conn: %w", err) + } + e.remoteConn = remoteConn + e.cancel = cancel + return addr, err +} + +// CloseConn close the remoteConn and automatically remove the conn instance from the map +func (e *ProxyWrapper) CloseConn() error { + if e.cancel == nil { + return fmt.Errorf("proxy not started") + } + + e.cancel() + + if err := e.remoteConn.Close(); err != nil { + return fmt.Errorf("failed to close remote conn: %w", err) + } + return nil +} diff --git a/client/internal/wgproxy/factory.go b/client/internal/wgproxy/factory.go deleted file mode 100644 index f4eb150b0..000000000 --- a/client/internal/wgproxy/factory.go +++ /dev/null @@ -1,22 +0,0 @@ -package wgproxy - -import "context" - -type Factory struct { - wgPort int - ebpfProxy Proxy -} - -func (w *Factory) GetProxy(ctx context.Context) Proxy { - if w.ebpfProxy != nil { - return w.ebpfProxy - } - return NewWGUserSpaceProxy(ctx, w.wgPort) -} - -func (w *Factory) Free() error { - if w.ebpfProxy != nil { - return w.ebpfProxy.Free() - } - return nil -} diff --git a/client/internal/wgproxy/factory_linux.go b/client/internal/wgproxy/factory_linux.go index d01ae7e74..369ba99db 100644 --- a/client/internal/wgproxy/factory_linux.go +++ b/client/internal/wgproxy/factory_linux.go @@ -3,20 +3,26 @@ package wgproxy import ( - "context" - log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/internal/wgproxy/usp" ) -func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { +type Factory struct { + wgPort int + ebpfProxy *ebpf.WGEBPFProxy +} + +func NewFactory(userspace bool, wgPort int) *Factory { f := &Factory{wgPort: wgPort} if userspace { return f } - ebpfProxy := NewWGEBPFProxy(ctx, wgPort) - err := ebpfProxy.listen() + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort) + err := ebpfProxy.Listen() if err != nil { log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err) return f @@ -25,3 +31,20 @@ func NewFactory(ctx context.Context, userspace bool, wgPort int) *Factory { f.ebpfProxy = ebpfProxy return f } + +func (w *Factory) GetProxy() Proxy { + if w.ebpfProxy != nil { + p := &ebpf.ProxyWrapper{ + WgeBPFProxy: w.ebpfProxy, + } + return p + } + return usp.NewWGUserSpaceProxy(w.wgPort) +} + +func (w *Factory) Free() error { + if w.ebpfProxy == nil { + return nil + } + return w.ebpfProxy.Free() +} diff --git a/client/internal/wgproxy/factory_nonlinux.go b/client/internal/wgproxy/factory_nonlinux.go index d1640c97d..f930b09b3 100644 --- a/client/internal/wgproxy/factory_nonlinux.go +++ b/client/internal/wgproxy/factory_nonlinux.go @@ -2,8 +2,20 @@ package wgproxy -import "context" +import "github.com/netbirdio/netbird/client/internal/wgproxy/usp" -func NewFactory(ctx context.Context, _ bool, wgPort int) *Factory { +type Factory struct { + wgPort int +} + +func NewFactory(_ bool, wgPort int) *Factory { return &Factory{wgPort: wgPort} } + +func (w *Factory) GetProxy() Proxy { + return usp.NewWGUserSpaceProxy(w.wgPort) +} + +func (w *Factory) Free() error { + return nil +} diff --git a/client/internal/wgproxy/proxy.go b/client/internal/wgproxy/proxy.go index b88df73a0..96fae8dd1 100644 --- a/client/internal/wgproxy/proxy.go +++ b/client/internal/wgproxy/proxy.go @@ -1,12 +1,12 @@ package wgproxy import ( + "context" "net" ) -// Proxy is a transfer layer between the Turn connection and the WireGuard +// Proxy is a transfer layer between the relayed connection and the WireGuard type Proxy interface { - AddTurnConn(turnConn net.Conn) (net.Addr, error) + AddTurnConn(ctx context.Context, turnConn net.Conn) (net.Addr, error) CloseConn() error - Free() error } diff --git a/client/internal/wgproxy/proxy_test.go b/client/internal/wgproxy/proxy_test.go new file mode 100644 index 000000000..b09e6be55 --- /dev/null +++ b/client/internal/wgproxy/proxy_test.go @@ -0,0 +1,128 @@ +//go:build linux + +package wgproxy + +import ( + "context" + "io" + "net" + "os" + "runtime" + "testing" + "time" + + "github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/internal/wgproxy/usp" + "github.com/netbirdio/netbird/util" +) + +func TestMain(m *testing.M) { + _ = util.InitLog("trace", "console") + code := m.Run() + os.Exit(code) +} + +type mocConn struct { + closeChan chan struct{} + closed bool +} + +func newMockConn() *mocConn { + return &mocConn{ + closeChan: make(chan struct{}), + } +} + +func (m *mocConn) Read(b []byte) (n int, err error) { + <-m.closeChan + return 0, io.EOF +} + +func (m *mocConn) Write(b []byte) (n int, err error) { + <-m.closeChan + return 0, io.EOF +} + +func (m *mocConn) Close() error { + if m.closed == true { + return nil + } + + m.closed = true + close(m.closeChan) + return nil +} + +func (m *mocConn) LocalAddr() net.Addr { + panic("implement me") +} + +func (m *mocConn) RemoteAddr() net.Addr { + return &net.UDPAddr{ + IP: net.ParseIP("172.16.254.1"), + } +} + +func (m *mocConn) SetDeadline(t time.Time) error { + panic("implement me") +} + +func (m *mocConn) SetReadDeadline(t time.Time) error { + panic("implement me") +} + +func (m *mocConn) SetWriteDeadline(t time.Time) error { + panic("implement me") +} + +func TestProxyCloseByRemoteConn(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + proxy Proxy + }{ + { + name: "userspace proxy", + proxy: usp.NewWGUserSpaceProxy(51830), + }, + } + + if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { + ebpfProxy := ebpf.NewWGEBPFProxy(51831) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %s", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %s", err) + } + }() + proxyWrapper := &ebpf.ProxyWrapper{ + WgeBPFProxy: ebpfProxy, + } + + tests = append(tests, struct { + name string + proxy Proxy + }{ + name: "ebpf proxy", + proxy: proxyWrapper, + }) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + relayedConn := newMockConn() + _, err := tt.proxy.AddTurnConn(ctx, relayedConn) + if err != nil { + t.Errorf("error: %v", err) + } + + _ = relayedConn.Close() + if err := tt.proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} diff --git a/client/internal/wgproxy/proxy_userspace.go b/client/internal/wgproxy/proxy_userspace.go deleted file mode 100644 index 234ea2a42..000000000 --- a/client/internal/wgproxy/proxy_userspace.go +++ /dev/null @@ -1,108 +0,0 @@ -package wgproxy - -import ( - "context" - "fmt" - "net" - - log "github.com/sirupsen/logrus" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -// WGUserSpaceProxy proxies -type WGUserSpaceProxy struct { - localWGListenPort int - ctx context.Context - cancel context.CancelFunc - - remoteConn net.Conn - localConn net.Conn -} - -// NewWGUserSpaceProxy instantiate a user space WireGuard proxy -func NewWGUserSpaceProxy(ctx context.Context, wgPort int) *WGUserSpaceProxy { - log.Debugf("Initializing new user space proxy with port %d", wgPort) - p := &WGUserSpaceProxy{ - localWGListenPort: wgPort, - } - p.ctx, p.cancel = context.WithCancel(ctx) - return p -} - -// AddTurnConn start the proxy with the given remote conn -func (p *WGUserSpaceProxy) AddTurnConn(turnConn net.Conn) (net.Addr, error) { - p.remoteConn = turnConn - - var err error - p.localConn, err = nbnet.NewDialer().DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) - if err != nil { - log.Errorf("failed dialing to local Wireguard port %s", err) - return nil, err - } - - go p.proxyToRemote() - go p.proxyToLocal() - - return p.localConn.LocalAddr(), err -} - -// CloseConn close the localConn -func (p *WGUserSpaceProxy) CloseConn() error { - p.cancel() - if p.localConn == nil { - return nil - } - return p.localConn.Close() -} - -// Free doing nothing because this implementation of proxy does not have global state -func (p *WGUserSpaceProxy) Free() error { - return nil -} - -// proxyToRemote proxies everything from Wireguard to the RemoteKey peer -// blocks -func (p *WGUserSpaceProxy) proxyToRemote() { - - buf := make([]byte, 1500) - for { - select { - case <-p.ctx.Done(): - return - default: - n, err := p.localConn.Read(buf) - if err != nil { - continue - } - - _, err = p.remoteConn.Write(buf[:n]) - if err != nil { - continue - } - } - } -} - -// proxyToLocal proxies everything from the RemoteKey peer to local Wireguard -// blocks -func (p *WGUserSpaceProxy) proxyToLocal() { - - buf := make([]byte, 1500) - for { - select { - case <-p.ctx.Done(): - return - default: - n, err := p.remoteConn.Read(buf) - if err != nil { - continue - } - - _, err = p.localConn.Write(buf[:n]) - if err != nil { - continue - } - } - } -} diff --git a/client/internal/wgproxy/usp/proxy.go b/client/internal/wgproxy/usp/proxy.go new file mode 100644 index 000000000..83a8725d8 --- /dev/null +++ b/client/internal/wgproxy/usp/proxy.go @@ -0,0 +1,146 @@ +package usp + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/errors" +) + +// WGUserSpaceProxy proxies +type WGUserSpaceProxy struct { + localWGListenPort int + ctx context.Context + cancel context.CancelFunc + + remoteConn net.Conn + localConn net.Conn + closeMu sync.Mutex + closed bool +} + +// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation +func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { + log.Debugf("Initializing new user space proxy with port %d", wgPort) + p := &WGUserSpaceProxy{ + localWGListenPort: wgPort, + } + return p +} + +// AddTurnConn start the proxy with the given remote conn +func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) (net.Addr, error) { + p.ctx, p.cancel = context.WithCancel(ctx) + + p.remoteConn = remoteConn + + var err error + dialer := net.Dialer{} + p.localConn, err = dialer.DialContext(p.ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) + if err != nil { + log.Errorf("failed dialing to local Wireguard port %s", err) + return nil, err + } + + go p.proxyToRemote() + go p.proxyToLocal() + + return p.localConn.LocalAddr(), err +} + +// CloseConn close the localConn +func (p *WGUserSpaceProxy) CloseConn() error { + if p.cancel == nil { + return fmt.Errorf("proxy not started") + } + return p.close() +} + +func (p *WGUserSpaceProxy) close() error { + p.closeMu.Lock() + defer p.closeMu.Unlock() + + // prevent double close + if p.closed { + return nil + } + p.closed = true + + p.cancel() + + var result *multierror.Error + if err := p.remoteConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) + } + + if err := p.localConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) + } + return errors.FormatErrorOrNil(result) +} + +// proxyToRemote proxies from Wireguard to the RemoteKey +func (p *WGUserSpaceProxy) proxyToRemote() { + defer func() { + if err := p.close(); err != nil { + log.Warnf("error in proxy to remote loop: %s", err) + } + }() + + buf := make([]byte, 1500) + for p.ctx.Err() == nil { + n, err := p.localConn.Read(buf) + if err != nil { + if p.ctx.Err() != nil { + return + } + log.Debugf("failed to read from wg interface conn: %s", err) + return + } + + _, err = p.remoteConn.Write(buf[:n]) + if err != nil { + if p.ctx.Err() != nil { + return + } + + log.Debugf("failed to write to remote conn: %s", err) + return + } + } +} + +// proxyToLocal proxies from the Remote peer to local WireGuard +func (p *WGUserSpaceProxy) proxyToLocal() { + defer func() { + if err := p.close(); err != nil { + log.Warnf("error in proxy to local loop: %s", err) + } + }() + + buf := make([]byte, 1500) + for p.ctx.Err() == nil { + n, err := p.remoteConn.Read(buf) + if err != nil { + if p.ctx.Err() != nil { + return + } + log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err) + return + } + + _, err = p.localConn.Write(buf[:n]) + if err != nil { + if p.ctx.Err() != nil { + return + } + log.Debugf("failed to write to wg interface conn: %s", err) + continue + } + } +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 779c27a4d..dc13706bf 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -168,7 +168,6 @@ func (c *Client) GetStatusDetails() *StatusDetails { BytesTx: p.BytesTx, ConnStatus: p.ConnStatus.String(), ConnStatusUpdate: p.ConnStatusUpdate.Format("2006-01-02 15:04:05"), - Direct: p.Direct, LastWireguardHandshake: p.LastWireguardHandshake.String(), Relayed: p.Relayed, RosenpassEnabled: p.RosenpassEnabled, diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index fb10a38d3..b942d8b6e 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v3.21.12 // source: daemon.proto package proto @@ -899,7 +899,6 @@ type PeerState struct { ConnStatus string `protobuf:"bytes,3,opt,name=connStatus,proto3" json:"connStatus,omitempty"` ConnStatusUpdate *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=connStatusUpdate,proto3" json:"connStatusUpdate,omitempty"` Relayed bool `protobuf:"varint,5,opt,name=relayed,proto3" json:"relayed,omitempty"` - Direct bool `protobuf:"varint,6,opt,name=direct,proto3" json:"direct,omitempty"` LocalIceCandidateType string `protobuf:"bytes,7,opt,name=localIceCandidateType,proto3" json:"localIceCandidateType,omitempty"` RemoteIceCandidateType string `protobuf:"bytes,8,opt,name=remoteIceCandidateType,proto3" json:"remoteIceCandidateType,omitempty"` Fqdn string `protobuf:"bytes,9,opt,name=fqdn,proto3" json:"fqdn,omitempty"` @@ -911,6 +910,7 @@ type PeerState struct { RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` + RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` } func (x *PeerState) Reset() { @@ -980,13 +980,6 @@ func (x *PeerState) GetRelayed() bool { return false } -func (x *PeerState) GetDirect() bool { - if x != nil { - return x.Direct - } - return false -} - func (x *PeerState) GetLocalIceCandidateType() string { if x != nil { return x.LocalIceCandidateType @@ -1064,6 +1057,13 @@ func (x *PeerState) GetLatency() *durationpb.Duration { return nil } +func (x *PeerState) GetRelayAddress() string { + if x != nil { + return x.RelayAddress + } + return "" +} + // LocalPeerState contains the latest state of the local peer type LocalPeerState struct { state protoimpl.MessageState @@ -2243,7 +2243,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xce, 0x05, 0x0a, 0x09, 0x50, 0x65, + 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, @@ -2255,209 +2255,210 @@ var file_daemon_proto_rawDesc = []byte{ 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, - 0x64, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x06, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, - 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, + 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, + 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, + 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, + 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, - 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, - 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, - 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, - 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, - 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, - 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, - 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, - 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, - 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, - 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, - 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, - 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, - 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, - 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, - 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, - 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, - 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, - 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, - 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, - 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, - 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, - 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, - 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, - 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, - 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, - 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, - 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, - 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, - 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, - 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, - 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, - 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, - 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, - 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, - 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, - 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, - 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, - 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, - 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, - 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, - 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, - 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, - 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, - 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, - 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, - 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x73, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, - 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x25, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, - 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, - 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, - 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, - 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, - 0x61, 0x6c, 0x6c, 0x22, 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, - 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, - 0x44, 0x12, 0x18, 0x0a, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x12, 0x40, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, - 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, - 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, - 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, - 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, - 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, - 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, - 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, - 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, - 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, - 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, - 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, - 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, - 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, - 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, - 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, - 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, - 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, - 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, - 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, - 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, - 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, - 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, - 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, - 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, - 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, - 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, + 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, + 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, + 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, + 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, + 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, + 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, + 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, + 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, + 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, + 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, + 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, + 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, + 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, + 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, + 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, + 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, + 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, + 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, + 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, + 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, + 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, + 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, + 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, + 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, + 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, + 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, + 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, + 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, + 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, + 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, + 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, + 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, + 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, + 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, + 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, + 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, + 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, + 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, + 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, + 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, + 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, + 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, + 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a, + 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, + 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, + 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, + 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, + 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, + 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, + 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, + 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, + 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, + 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, + 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, + 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, + 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, + 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, + 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, + 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, + 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, + 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, + 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, + 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 43c379fb5..384bc0e62 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -168,7 +168,6 @@ message PeerState { string connStatus = 3; google.protobuf.Timestamp connStatusUpdate = 4; bool relayed = 5; - bool direct = 6; string localIceCandidateType = 7; string remoteIceCandidateType = 8; string fqdn = 9; @@ -180,6 +179,7 @@ message PeerState { bool rosenpassEnabled = 15; repeated string routes = 16; google.protobuf.Duration latency = 17; + string relayAddress = 18; } // LocalPeerState contains the latest state of the local peer diff --git a/client/server/debug.go b/client/server/debug.go index 1187f3187..5ed43293b 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -369,8 +369,8 @@ func seedFromStatus(a *anonymize.Anonymizer, status *peer.FullStatus) { } for _, relay := range status.Relays { - if relay.URI != nil { - a.AnonymizeURI(relay.URI.String()) + if relay.URI != "" { + a.AnonymizeURI(relay.URI) } } } diff --git a/client/server/server.go b/client/server/server.go index 8173d0741..0a4c18131 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -12,7 +12,6 @@ import ( "github.com/cenkalti/backoff/v4" "golang.org/x/exp/maps" - "google.golang.org/protobuf/types/known/durationpb" log "github.com/sirupsen/logrus" @@ -143,10 +142,12 @@ func (s *Server) Start() error { s.sessionWatcher.SetOnExpireListener(s.onSessionExpire) } - if !config.DisableAutoConnect { - go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + if config.DisableAutoConnect { + return nil } + go s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) + return nil } @@ -154,7 +155,7 @@ func (s *Server) Start() error { // mechanism to keep the client connected even when the connection is lost. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, - mgmProbe *internal.Probe, signalProbe *internal.Probe, relayProbe *internal.Probe, wgProbe *internal.Probe, + runningChan chan error, ) { backOff := getConnectWithBackoff(ctx) retryStarted := false @@ -185,7 +186,15 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf runOperation := func() error { log.Tracef("running client connection") s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) - err := s.connectClient.RunWithProbes(mgmProbe, signalProbe, relayProbe, wgProbe) + + probes := internal.ProbeHolder{ + MgmProbe: s.mgmProbe, + SignalProbe: s.signalProbe, + RelayProbe: s.relayProbe, + WgProbe: s.wgProbe, + } + + err := s.connectClient.RunWithProbes(&probes, runningChan) if err != nil { log.Debugf("run client connection exited with error: %v. Will retry in the background", err) } @@ -576,9 +585,22 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) - go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + runningChan := make(chan error) + go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) - return &proto.UpResponse{}, nil + for { + select { + case err := <-runningChan: + if err != nil { + log.Debugf("waiting for engine to become ready failed: %s", err) + } else { + return &proto.UpResponse{}, nil + } + case <-callerCtx.Done(): + log.Debug("context done, stopping the wait for engine to become ready") + return nil, callerCtx.Err() + } + } } // Down engine work in the daemon. @@ -590,28 +612,19 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes return nil, fmt.Errorf("service is not up") } s.actCancel() + + err := s.connectClient.Stop() + if err != nil { + log.Errorf("failed to shut down properly: %v", err) + return nil, err + } + state := internal.CtxGetState(s.rootCtx) state.Set(internal.StatusIdle) - maxWaitTime := 5 * time.Second - timeout := time.After(maxWaitTime) + log.Infof("service is down") - engine := s.connectClient.Engine() - - for { - if !engine.IsWGIfaceUp() { - return &proto.DownResponse{}, nil - } - - select { - case <-ctx.Done(): - return &proto.DownResponse{}, nil - case <-timeout: - return nil, fmt.Errorf("failed to shut down properly") - default: - time.Sleep(100 * time.Millisecond) - } - } + return &proto.DownResponse{}, nil } // Status returns the daemon status @@ -745,11 +758,11 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { ConnStatus: peerState.ConnStatus.String(), ConnStatusUpdate: timestamppb.New(peerState.ConnStatusUpdate), Relayed: peerState.Relayed, - Direct: peerState.Direct, LocalIceCandidateType: peerState.LocalIceCandidateType, RemoteIceCandidateType: peerState.RemoteIceCandidateType, LocalIceCandidateEndpoint: peerState.LocalIceCandidateEndpoint, RemoteIceCandidateEndpoint: peerState.RemoteIceCandidateEndpoint, + RelayAddress: peerState.RelayServerAddress, Fqdn: peerState.FQDN, LastWireguardHandshake: timestamppb.New(peerState.LastWireguardHandshake), BytesRx: peerState.BytesRx, @@ -763,7 +776,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { for _, relayState := range fullStatus.Relays { pbRelayState := &proto.RelayState{ - URI: relayState.URI.String(), + URI: relayState.URI, Available: relayState.Err == nil, } if err := relayState.Err; err != nil { diff --git a/client/server/server_test.go b/client/server/server_test.go index 6a3de774c..9b18df4d3 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -6,10 +6,11 @@ import ( "testing" "time" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel" + "github.com/netbirdio/management-integrations/integrations" + log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/keepalive" @@ -73,7 +74,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Setenv(maxRetryTimeVar, "5s") t.Setenv(retryMultiplierVar, "1") - s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.mgmProbe, s.signalProbe, s.relayProbe, s.wgProbe) + s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil) if counter < 3 { t.Fatalf("expected counter > 2, got %d", counter) } @@ -129,8 +130,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve if err != nil { return nil, "", err } - turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) + + secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } @@ -158,7 +160,7 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) { log.Fatalf("failed to listen: %v", err) } - srv, err := signalServer.NewServer(otel.Meter("")) + srv, err := signalServer.NewServer(context.Background(), otel.Meter("")) require.NoError(t, err) proto.RegisterSignalExchangeServer(s, srv) diff --git a/client/testdata/management.json b/client/testdata/management.json index 4745f2e8c..674c66e06 100644 --- a/client/testdata/management.json +++ b/client/testdata/management.json @@ -20,6 +20,13 @@ "Secret": "c29tZV9wYXNzd29yZA==", "TimeBasedCredentials": true }, + "Relay": { + "Addresses": [ + "localhost:0" + ], + "CredentialsTTL": "1h", + "Secret": "b29tZV9wYXNzd29yZA==" + }, "Signal": { "Proto": "http", "URI": "signal.wiretrustee.com:10000", @@ -34,4 +41,4 @@ "AuthAudience": "", "AuthKeysLocation": "" } -} \ No newline at end of file +} diff --git a/client/ui/bundled.go b/client/ui/bundled.go new file mode 100644 index 000000000..e2c138b14 --- /dev/null +++ b/client/ui/bundled.go @@ -0,0 +1,12 @@ +// auto-generated +// Code generated by '$ fyne bundle'. DO NOT EDIT. + +package main + +import "fyne.io/fyne/v2" + +var resourceNetbirdSystemtrayConnectedPng = &fyne.StaticResource{ + StaticName: "netbird-systemtray-connected.png", + StaticContent: []byte( + "\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\x00\x00\x00\x01\x00\b\x06\x00\x00\x00\\r\xa8f\x00\x00\x00\xc3zTXtRaw profile type exif\x00\x00x\xdamP\xdb\r\xc3 \f\xfc\xf7\x14\x1d\xc1\xaf\x80\x19\x874\xa9\xd4\r:~\r8Q\x88r\x92χ\x9d\x1cư\xff\xbe\x1fx50)\xe8\x92-\x95\x94СE\vW\x17\x86\x03\xb53\xa1v>@\xc1S\x1dNɞų\x8c\x86\xa5\xf8\xeb\xa8\xd3d\x83T]-\x17#{Gc\x9d\x1bEGf\xbb\x19\xc5E\xd2&b\x17[\x18\x950\x12\x1e\r\n\x83:\x9e\x85\xa9X\xbe>a\xddq\x86\x8d\x80F\x92\xbb\xf7ir?k\xf6\xedm\x8b\x17\x85y\x17\x12t\x16\xd11\x80\xb4P\x90\xdaE\xf5\xf0\xa1\xfc#u-\x92:[L\xe2\vy\xda\xd3\x01\xf8\x03\xda\xd4Y\x17ݮ\xb7\xee\x00\x00\x01\x84iCCPICC profile\x00\x00x\x9c}\x91=H\xc3@\x1c\xc5_S\xa5\"-\x0e\x16\x14\x11\xccP\x9d\xec\xa2\"\xe2T\xabP\x84\n\xa5Vh\xd5\xc1\xe4\xd2/hҐ\xa4\xb88\n\xae\x05\a?\x16\xab\x0e.κ:\xb8\n\x82\xe0\a\x88\xb3\x83\x93\xa2\x8b\x94\xf8\xbf\xa4\xd0\"ƃ\xe3~\xbc\xbb\xf7\xb8{\a\b\x8d\nSͮ\x18\xa0j\x96\x91N\xc4\xc5lnU\f\xbcB\xc0\x00B\x18\xc1\xac\xc4L}.\x95J\xc2s|\xdd\xc3\xc7\u05fb(\xcf\xf2>\xf7\xe7\b)y\x93\x01>\x918\xc6t\xc3\"\xde \x9e\u07b4t\xce\xfb\xc4aV\x92\x14\xe2s\xe2q\x83.H\xfc\xc8u\xd9\xe57\xceE\x87\x05\x9e\x1962\xe9y\xe20\xb1X\xec`\xb9\x83Y\xc9P\x89\xa7\x88#\x8a\xaaQ\xbe\x90uY\xe1\xbc\xc5Y\xad\xd4X\xeb\x9e\xfc\x85\xc1\xbc\xb6\xb2\xccu\x9a\xc3H`\x11KHA\x84\x8c\x1aʨ\xc0B\x94V\x8d\x14\x13iڏ{\xf8\x87\x1c\x7f\x8a\\2\xb9\xca`\xe4X@\x15*$\xc7\x0f\xfe\a\xbf\xbb5\v\x93\x13nR0\x0et\xbf\xd8\xf6\xc7(\x10\xd8\x05\x9au\xdb\xfe>\xb6\xed\xe6\t\xe0\x7f\x06\xae\xb4\xb6\xbf\xda\x00f>I\xaf\xb7\xb5\xc8\x11з\r\\\\\xb75y\x0f\xb8\xdc\x01\x06\x9ftɐ\x1c\xc9OS(\x14\x80\xf73\xfa\xa6\x1c\xd0\x7f\v\xf4\xae\xb9\xbd\xb5\xf6q\xfa\x00d\xa8\xab\xe4\rpp\b\x8c\x15){\xdd\xe3\xdd=\x9d\xbd\xfd{\xa6\xd5\xdf\x0fںr\xd0VwQ\xba\x00\x00\rxiTXtXML:com.adobe.xmp\x00\x00\x00\x00\x00\n\n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\xf0C\xff\xd9\x00\x00\x00\x06bKGD\x00\xff\x00\xff\x00\xff\xa0\xbd\xa7\x93\x00\x00\x00\tpHYs\x00\x00\v\x13\x00\x00\v\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\atIME\a\xe8\x02\x17\r$'\xdd\xf7ȗ\x00\x00\x13;IDATx\xda\xed\x9d]o\x14W\x9a\xc7\xff\xa7\xaamh\xbf\xc46,I`\x99\xa1\xc3\ni\xb5{1\x95O0\xe4\x1b\xc0'X\xf2\t`.W`hp\xa2\xb9\fH{O\xa3\xcc\xc5\xecJ3q\xa4\x1d\xed\xcdJx>Aj/\"EBJګL \xb1\x00g\xf1\v\xb6\xbb\xeb\xec\x85mb\f\xb6\xfb\xa5^Ω\xfa\xfd\xee\x928v\xf7\xa9z\xfe\xcfs\x9e\xa7ο$\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\u0603a\t\xc0g\xd6\x7f\x1f5\x92\x8e\"k\xd4\b\xa4s\xb2jH\x9afez\n\xfe\xdb\b\x00x\x81mF\xd3/CE]\xa3(\x94~c\xa5\x8b;\xc1\x0e\x83E\x7f{\xecF\xfcA\x8d\x95\x00g\xb3\xfb\\tQ\xd2o\xadtq]\xba(I\x81\x95,K345\xa3˒\x84\x00\x80SY~5Х0\xd0o\x13\xabK\x96R>\x9b\xe4o\xd4\x1a\xbd\x1e\xc7\b\x008\x93\xe9\xadtkM\x8a\x02i\xdaZ\x9aS\x99\x12\xea\xf6\xabJ\x80Հ\x02\xf7\xf4W\x13\xe9\xdan\xa6'\xe8sXw\xe9\xf6ؿ\xc6\xed_Z\x01\x00\x05d{\xed\xec\xe9!\xcf\xda\x7f\xbb\xf1\xf7Z/\x80U\x81<\x03\xdf\x12\xf8\xc5\xc5\x7f\xf2K\xe9O\x05\x00d\xfcje\xffx\xecF\xfc\xe1\xfe\x7fM\x05\x00\xd9\x04~3j$5ݲVWX\r\a\xe2?\xdc\x1e\xfb!\x00\x909ks\xd1\xd5Dj\x1a\xcb\x18ω\xe07j\xd5\xf74\xfe\x10\x00Ȅ\x95O\xa3(Ht_R\xc4\xdeҙҿ\xbdw췟\x80\x15\x82\xb4\xb2~\x90\xe8+I\x11\xab\xe1\x0e\xd6\xea\xc1A\xd9\x7f[\x1f\x00\x86\xdc\xeb\xdbP\xf7E\x93\xcf\xc9\xec\xbf\x7f\xec\xc7\x16\x00\xd2\v\xfe\xb9\xe8b\"}axd\xd7\xcd\xf8O\x0e.\xfd\xd9\x02\xc0\xb0\xc1\x7f\xcbJ\x0f\t~G\t4_\xbf\x19\xb7\x8e\xfa1*\x00\xe8\x9b\xd5O\xa2\xfb\x8c\xf7\x1c\xcf\xfe\x81~\xd7\xcb\xcf!\x00\xd03\xb6\x19M\xaf\x87\xfaB\x96\xfd\xbe\xd3\xc1\x7f\xc8؏-\x00\fV\xf27\xa3\xc6z\xa8\x87\xa2\xd9\xe7x\xf4\x1f>\xf6\xa3\x02\x80\x81\x82\xdf\xd6\xf4\x10\a\x1e\x0f\xe2?\xd1\xed\xfa\x8d\u07b2\xff\xb6^\x00\x10\xfc\xa5\xc9\xfeG\x8d\xfdJW\x01\xd8f4\xfd\xf2ؾNt\xe7\xe8\x9b5\f\xb4\xdc\r\xb4\xbc\xfb\xcf\xc77\xb4l\x9a\xf12wѾ=?\xc1\xef\xcf\xf5\xb2\xbd5\xfe\x9c\xac\x00\xd6\x7f\x1f5\xc2D\xd3[\x89\x1a\xd6jZ\x81\xa6\x8d\xd5t`tn\xe7\xcb5$M\xcb\xec\x04{\x867\xa5\x95\x96\x8dѲ\xacvK\xa9ec\xb4\x9cX-Z\xa3ec\xd5\x0e\xa4e\xd5\xd4\xee\xb5\xd9\xe2#ks\x11O\xf6\xf9\x92\xfc\x8dZ\xf5\x1b\xf1\xc7N\n\x80mF\xd3[#jlv\x15)\xd0\xf4\x1e\xfb憌\xa6\xbd\xcf0F\xed\x1d\xb1X\xb6\xd2\xffX\xabvh\xd4>\xdeU\xeckU\xb1v'\xfaLF\xd7\b-On\xc1\x9a>\x18$\x19\x99,\x82<\b\xf4\x1b\xb3\xed\xed\x16Y\xa9Q\xe5\x87E\xac\xb4l\xa4XFq\"-\x86V\xb1\xeb°\xf3\x90O\x93\xb0\xf2\xe6\x1e\xbb=>\x1b\x0ft\xbd\xcc\xc0\x81nuq'\x93Gv\xfb\xf4\x17O\x84\xf5G,\xa9\x9d\x18\xfd5\xb4\x8a\xeb\xb3\xf1\x82\v\x1fju.\xbad\xa4/\xb8<\xfeT\x9f\xf5\x8e>\x1c4\xa1\x98\xa3\x82}-\xd4Ek\xd4\xe0e\f\xb9\xb0 \xa3\xd8Z\xfdu\xac\xab\x85\xbc\xab\x04:\xfe\x1eƿ\xd5ǽ<\xf2ۓ\x00l~\x1aE\x9bV\x17\tv\x87\xaa\x04\xa3\x05c\xf5e\x1e\x15\xc2\xda'\xd1w\\s\xbf\xb2\x7f\xbfc\xbf7~\xc5\xda\\tU\xd2%\xcax/z\t\v\x89\u0557\xe1\x88\x16Ҟ>\xb0\xef\xf70\xfe\al\xfc\xbd\xf6;V>\x8d\"\x93p\xaa\xcb\xc7\xedBb\xf5 \r1\xd89\xd3\xff\x1dK\xeaQ\xf0\x0f8\xf6{\xeb\x16`ǹ\xf5!\xcbZM1X\x9d\x8b\xbe2\xcc\xfb\xbd\xaa\x06\x83\x9a>L\xa3\n|\xd5\x03X\xbf\x13]\xb1F\xf7Y^\uf677҃\xf1\xd9x\xbe\x97\x1f^\xb9\x13]\t\xb8\xee\xbe\t\xc0\xc0c\xbf\x03\x05\x00\x11([\x8d\xa8\xb6\x12͛\x11\xdd;,S\xd0\xf8\xf3\xef\xba\x0e\xdb\xf8\xdb\xcbkǁ\xeb7\xe3\x96U\xefG\t\xc1\xe94ѐ\xd15\xdb\xd1w\xab\x9fD\xf7w^\xb5\xfd\xfa\xde\x7f.\xbaE\xf0{\x16\xffI\xba\xf1i\x0e\xd8\x136\xcd\xf6\xdb\\\xa0d\xd9#It{\xe2fܢ\xf1\xe7\xe5\xf5[\x18\xbb\x11\x7f\x94\xb9\x00 \x02\x15\x10\x82\xe7\x13\xed`z\xe5\"\x8b\xe1\xd1eKa\xec׳\x00 \x02%\xde\x1dlִ\xf5\xf5\xafeF;\nO?Wp\xe2\x05\x8b\xe2z\xf0\xa74\xf6;\xb4\a\xb0\x9f\xf1ٸ\x99H\x0fX\xfer\xd1yt\xe6\x95\x10t\x16Oi\xeb\xeb_+y6\xc9\xc28\\\xb1\xf5c\xf3\x95\x9a\x00H\xd2\xc4l|\x05\x11(\x0fɳI\xd9\xcd\xda\x1b\x15\xc1\xae\x10ؕ:\x8b\xe4\x1e\xf7\xb2\xf2\x9d\xe8\xf94\xe0\xfa\\\xf4\x90w\xbb{^\xfaw\x03u\xbe9\xfb\x86\x00\xbc\x91\x15N\xac(<\xfd\\ft\x8bEs \xfb\xa79\xf6\xeb\xbb\x02\xd8\xe5eW\x97\xed\xf6\x11V\xf05\xfb\xff4ud\xf0oW\t\x13\xda\xfa\xfaW\xea>\x9eaъ\x8e\xff$۱|_~\x00ϛ\xd1\xf4h\xa8\x87<6\xeaa\xf6\xdfi\xfc\xf5}\x83\x8cvT\xbb\xf0\x98j\xa0\x88\xe0Ϩ\xf17P\x05 I3\xcdx9\xe8게\xda\\\x1e\xbf\x184\x9bo\vǯ\xd4\xfd\xfe\xa4l\x97\xd7H\xe4J\x98\xfdCy}_\xd1z3n\x9b\x8e>B\x04<\xca\xfe+\xf5\xa1\xbb\xfcݥ\xa9\x9d\xfe\xc1\b\v\x9a\xc75\x93n\xe7a8;\x90\xa4#\x02~\xd1Y<\x95\xe26\x82\xde@\xf6\xb5\xbf\xdaAM\xad<\xfe\xd4\xc05\x1d\"\xe0\ao\x1b\xfb\r\xbd\x9dx2\xa3-\xaa\x81\xec\xe2?\xc9'\xfbok͐`(\xe2p\x19\xb9YS\xe7љ\xd4\x05\xe0\xd5\xcd3\xdaQx\xf6\xa9\x82\xa9U\x16;\xc5\xec\x9f\xe5\xd8/\xb5\n`\x97\x89\xebql\x03}d%ު\xe3\x18\xdd\xc73\x99\x05\xff+\x81\xf9\xf6=\xb6\x04)R3\xba\x9c\xe7\xdfK\xa5\xad;q=\x8e%}\xcc\xe5s+\xfb\xe7\xf5xo\xf7Ɍ:\x8b\xef2%\x186\xf9\x1b\xb5F\xb7c\xc9/\x01\x90\xa4\xf1\xd9x\xdeXD\xc0\xa5\xec\x9fo\xafa\x82)\xc1\xb0\x84\xf9{q\xa4*\xd9\xf5\x9bqK\xa6\xff\x17\x14B\xda\xc18Y\xc8\xe1\x9e\xed\x9e\xc3iD`\x90\xb5S~\x8d\xbf\xd7[\x0e\x19\xc01\xe2b\xd9\xfa\xfaי\xee\xfd\x8f\xced\x89F.<\x96\xa9op1z\x8b\xc2\\\x1b\x7f\x99U\x00{\xb6\x03M\xac\xc5\n*\xfd{|\xde?\xdb\x0f\x11h\xeb\xd1i%?\x8fsAz\x89\xff\xa4\xb8X\xc9\xf4\xed\xc0T\x02E\x94\xe0g\x8a\x17\x80\xbd\xc5\xc0\xb9%\x85\x18\x8e\x1c\x16\x81\xf1؍\xf8â\xfe|\xa6m\xdb\x1dC\x91{\\\xe5\x9c\x12o\xc6c\xbf\x81>\xd3\xe2)u1\x1b98\xfe\xc3|\xc7~\xb9\n\x80$M\xcc\xc6\xd70\x14\xc9'\xfb\xbb\xea\xea\x83\b\x1c\x10\xfcF\xad\"\x1a\x7f\xb9\n\xc0\x8e\b\xe0*\x941\x9do\xdfw\xbb:A\x04\xf6\x97\xfe\xed\"\xc6~\x85\b\x80$muu\rC\x91lH\x9eMʮ\x8f\xba\xbfE\xf9\xfe\xa4\xec\xfa1.\x98$k\xf5\xa0\xe8쿭C9\x82\xa1HF\xe2Z\xf4د\x1f\xc2D#\xff\xf8\xb7j\x1b\x8c\x148\xf6+\xac\x02\x90\xb6\rE6\xbb\x9c L5\xab:\xd8\xf8;\xfc\x03\a\x95\x7fX\xa8ȱ_\xa1\x02\xb0+\x02\x1c#N\xa9\x8cܬ\xa9\xfbd\xc6\xcb\xcf\xdd\xf9\xf6\xbdJ\x9e\x1d0F\xad\xfa\u0378UY\x01\x90\xf0\x12H3\xfb{+^\xeb\xa3\xea~\xffwջh\xa1[\x0f\xc8\x15&\xc1\x88\xc0\xb0\x01t\xcc\xfb\x97y$\xcf&*u\x94\u0605\xb1\x9f3\x02\xb0W\x04\xf0\x12\xe8\x9fη\uf563\x8ay2S\x8d\x97\x9182\xf6sJ\x00vE\x00C\x91~3\xe7\xa4_\x8d\xbf#\xd8\xfa\xf6\xbd\xd27\x05\xf3\xb4\xf9\xeaO\x97\x1c\x01k\xb1\x1eK\x7f\a\x9f\xf7O\xe5F\x9cx\xa9\x91\v?\x946\xfb\xbb2\xf6s\xae\x02\xd8e\xe2z\x1c\a\x16/\x81#\xb3\xff\xd3\xc9\xd2\x05\xbf$ٕ\xe3\xea.M\x953\xfe\x1d6\xcaqj\x0eS\xbf\x19\xb7p\x15:<\xfb\xfb8\xf6\xeb\xb9\x1f\xf0\xfd\xc9\xd2\xf5\x03\x8cQ\xab>\x1b/ \x00}\x88\x00\xaeB\a\x04H\x05:\xe6\x9d\xc5S\xe5z> t\xdb\x17\xc3ɕ\x1e\xbb\x11\xdf\xc5Pd_\xf6\xffy\xdc\xfb\xb1_\xafUNR\x12\xa1+\xca\xe6\xcb{\x01\x90p\x15z#3~\x7f\xb2:\x95\xceҔ\xff[\x01\xa3\xf6XWw]\xff\x98N\xd7Z\x88\xc06e\x1b\xfbUA\xf0L\xa2ۦ\x19;?\xda6>,\xe6\xca\\t7\x90\xaeV\xb2\xf4/\xe9د\xa7\xed\xf3٧\nO\xfd\xecg\xf6wt\xec\xe7U\x05\xb0K\x95]\x85\xbc;\xed\x97\xf6w\xf7\xb0!hB}\xe4\xcbg\xf5fu\xab\xe8*\xe4\xb2\xcdW>\n\x10xw`\xc8\xc5\xe7\xfdK!\x00R\xf5\\\x85*yZn\x1fɳ\to\x1a\x82VZv}\xec\xe7\xb5\x00\xec\x1a\x8aTA\x04\x92g\x93J~\x1e\x13H\x1d\x7fƂ\xf7|\xca\xfe\x92'M\xc0\xfd\xac7\xa3\x86\xad顬\x1ae\xbd齲\xf9ʁ\x91\v\x8fe&\xd6]\x8e$o\x1a\x7f\xdeV\x00\xbb\x94\xddK\xa0ʍ?_\xab\x00\x97l\xbeJ/\x00e\x16\x01\xbbYSR\xd2C1C\xad\xcb\xcaqw{\x01\x81\xe6]\xb2\xf9\xaa\x84\x00\x94U\x04|\x1d}U\xb9\n0\x81\xbfgW\xbc\xbf\xd3\xca\xe4*T\xf9\xb1\x9f\x87U\x80oc\xbf\xd2\t\xc0\xae\b\x94\xc1U\xa8\xf3\xe8\fQ\xeeS\x15\xe0\xa8\xcdW\xe5\x04@\xda1\x14Q\xb1/Z\x1c\x86*>\xef\xef{\x15\xe0\xaa\xcdW%\x05@\x92\xea\xb3\U000423c6\"\xb6\x1bT\xca\x1d\xb7\x14U\x80Q\xdb\xd7\xc6_i\x05@\xf2\xd3U(\xf9i\x8a\xec\xdfo\x15P\xb0\x89\xa8-\x89}])\xdb\xcd\xf5\x9bq˗c\xc4e\xb7\xf9\xcan\xcb4Q\\\xf27j\x8d\xcf\xc6\xf3\b\x80\xc3\xf8\xe2%@\xe9?\xe0\xba\xfd4Uܸ4,\x8fGE\xa9\aή\x8b\x80]\xa93\xf6\x1bX\x01\x82B\xd6\xce\a\x9b/\x04\xc0\x13\x11\xe8,\x9e\"\x90\x87\xd9\x06,\x8f\xe7\\\xfb\xab\x1d\xd4\xd4*\xd3\x1aV⑳\xf1ٸ隗\x00c\xbf4*\xa8\xe3\xb2\xeb\xc7\xf2\x8b\xff\xa4\\ٿ2\x02 \xb9e(b7k\xec\xfd\xd3\x12Ҽ\x8eL\x97d\xecWY\x01\xd8\x15\x01#-\x14~\xd32\xf6K\xaf\x15\xf0S>\a\xa7j\xc6߇\xcc\x10\x80=\xbc\xec\xear\x91\x86\"v\xb3V\xdaW`\x15\xa3\x00A\xe6O\x06\x1a\xa3\xd6\xe8\xf58F\x00J@ѮB\x94\xfe\x19\xaci\xd6ۀ\xb0\xbc\xd6\xf4\x95@X\xbd\xec\x8f\x00d \x02\x8c\xfd\x1c\xe8\x03\xf4)\x02I\xc5\x1a\x7f\b\xc0\x80\"Ћ\xa1\bo\xf7q\xa1\x0f\xd0\xc76\xc0\xa8\x1d\xd6t\xb7\xaak\x85\x00\xf4\xc1Q\xaeB\xd8|\xb9\xb2\r\xe8\xdd&\xac\x8c6_\xfd`\xb8]\xfagu.\xfa\xcaH\xd1k7]7P盳\b\x80\v7\xf5hG#\xff\xfc\xbf=e\xff\xb1\x1b\xf1\aU^+*\x80\x01x\x9b\xa1\b6_\x0eU\x00\x9b\xb5\x9e\x9a\xb0>\xbeF\x0e\x01p\x80]W\xa1\xdd\x13\x84\xbc\xdd\xc7A\x8e\xd8\x06\x18\xa3V}6^@\x00``\x11\xd8=F\xcc\x13\x7f\xee\x91\x1c5\t\xa8\xe8\xd8\x0f\x01H\x91z3n'\x1b\xb5{\xae\xbc\xae\x1a\xf6l\x03\x0e\xa9\x00\xcan\xf3\x85\x00乀\xc7:Wk\x17~\x90\u0084\xc5pI\x00\xd6\x0e\xa8\x00\x8c\xdac\xdd\xea\x8e\xfd\x10\x80\x14Y\x9b\x8b\xaeʪaF;\x1a\xb9\xf0\x18\x11pI\x00\x0ehȚD\xb7M3^f\x85\x10\x80\xa1XoF\r\x19]{uc\xd574r\xfeG\x16\xc6\x15\xba\xc1\x9b\x0eA%}\xbb\x0f\x02P\x00IM\xb7d\xd5x\xed\xfe\x9aXWxn\x89\xc5q\xa6\x0f\xf0\xfa6\xc0\x84\xfa\x88UA\x00R\xc9\xfe\xc6\xea\xca\xdb\xfe[x\xe2\x05\"\xe0\xe06\xa0J6_\b@\xd6\xd9?\xd4\x17\x87\xfd\xf7\xf0\xc4\v\x85\xa7\x9f\xb3P\x8e\b\x80\x95\x96\x19\xfb!\x00\xa9\xb0r'\xba\xb2\xff1්\xc0\xfb\xcf\x11\x01w*\x80{d\x7f\x04 \x9d\x05\vt\xabןE\x04\nf\xedX%\xde\xee\x83\x00\xe4\xb5\xf7\x9f\x8b\xdeh\xfc!\x02\x0eW\x00ݠ\x926_\xfd\xc0i\xc0^\x83\xbf\x195l\xa8\xef\x06\xfd\xff;\x8b\xa70\n\xc9\xfb\xe6~g\xbd=\xd5\xfa\xaf\x0fX\t*\x80\xa1Ij\xbd\x97\xfeo\xa3vnI\xc1\x89\x17,d\x8e\x84'V\xc8\xfeT\x00\xc5g\xff\xbdl=:û\x02\xf2\xc8lc\x9b\xf3\xef\xfc\xe1?/\xb3\x12T\x00\xc3\xef%kz\x98\xd6\xef\x1a9\xffD\xa6\xbeɢf\x99\xd5F;\xeav&~\xc7J \x00C\xb3r'\xba\xd2o\xe3\xef\xf0\xba4\xd1ȅ\x1f\x10\x81,K\xff\xa9\xf5\xd6\xcc\x1f\xff\xd8f%\x10\x80\xa1K\xff~\xc6~\xfd\x88@\xed\xfc\x13\x99\xd1\x0e\x8b\x9c>\xed\xb0\xb1\xc4\xde\x1f\x01H#P\xf5/\xa9f\xff}ej\xed\xc2\x0f\x88@\xda7\xf4\xf4\x1ag\xfd\xfb\xb9\x0fY\x82\x83\xb3\x7fZ\x8d\xbfC\xfb\v\x9b5u\x1e\x9d\xc1O0\x8d\x9b\x99\xb1\x1f\x15@Z\f;\xf6\xa3\x12(\xe0f\x9e\xd8\xfa\x98U@\x00\x86fu.\xbat\xd0i\xbf\xccD\xe0\xfc\x8f\x18\x8a\fs#\x8fo\xb4&\xff\xed\xbf\x17X\t\x04`\xf8\x804\xfa,\xf7\xbfY\xdf\xc0Uh\b\x01\x1d9\xb9J\xe3\x0f\x01\x18\x9e\xd4\xc7~}\x8a@\r/\x81\xfe\x19\xedܮ\xdf]h\xb3\x10\b\xc0Pd6\xf6\xeb\xe7\x82L\xadb(\xd2\x1f\xedw\xce\xff\x80\xc9'\x020\xff\v\x8d?*\x80\xe1qe\xec\xd7\x0f\xb5\xb3O+\xed%\x10\x9e\xfa?J\x7f*\x80\x94\xb2\xbfcc\xbf\x9e\xe9\x06\xdb\xd6b\xeb\xa3\xd5\xcaV#\xc9\xfc;\xff>\x8f\xcd\x17\x15@\n\xd9\xff\x88\xb7\xfb\xb8\x9d\x06w\\\x85*t\x82Ќv\xd45DZ\xf9B\x00\x86\xa7\u05f7\xfb\xb8.\x02U:F\x8c\xcd\x17\x02\x90ޗv|\xec\xd7OV\xac\x88\b`\xf3\x85\x00\xa4\xb4\xf7\xf7d\xec\x87\b\xfc\x82\x95\xb0\xf9\xca\xea\xfe\xa9T\xf0\xfb\xdc\xf8;*H6k\xda\xfa\xe6\xac\xd4-\x97\xa6\xf3\xbc?\x15@j\xe4e\xf3UT%PFC\x11l\xbe\xa8\x00Ra\xe5\xd3(\n\x12}U\xf6\xefi\u05cfi\xeb\xd1\xe9RT\x02\xc1\xf8F\xeb\x9d\xcf\xff\x82\x00P\x01\xa4\xf0E\xad\xc7c\xbf~\x14\xbd\xbeQ\n/\x013\xdaQwk\x92\xc6\x1f\x02\x90B\xf6/\xd0\xe6\xab\b\xc2\x13/\xfcw\x15\x1a\xed\xdcf\xec\x87\x00\f\x8d\v6_\x85\x89\x80\xbf^\x02\xd8|!\x00)\xed\x89\x03]\xadR\xf6\x7fM\x04<5\x14\t\xde_\xc6\xe6+\xaf\xadVٳ\x7fY\xc7~\xfd\xe0\x93\xb5\x18c?*\x80\xd4(\xf3د\xac\x95\x006_\b@*\xac܉\xae\xe4\xf9v\x1f/D\xc0qW\xa1`|\x83\xe7\xfd\x11\x80\x94\xbeX@\xf6\x7fC\x04\xdcv\x15\xc2\xe6\v\x01H\x87\xb5\xb9\xa8\xb2\x8d\xbf\xa3\xa8\x9d[\x92\x99x\xe9\xde\xde\x1f\x9b\xafbֽl_\xc8G\x9b\xaf\xdcq\xccP\xc4Ԓ\xf6\xd4\x7f\xcc\xd3\xf8\xa3\x02\x18\x1e\x1fm\xbe\xf2\xdf\v\xec\x18\x8a8b-\x16\x9e}J\xe9O\x05\x90R\xf6g\xec\xd73v\xb3\xa6Σ3\xb2\x9bŽ\x1e\x02\x9b/*\x80\x14S\x89\xeesI\xfbP\xff\x82\xbd\x04\xb0\xf9B\x00Rc\xe5Nt\xc5J\x17\xb9\xa4\xfe\x88\x80\x19\xe92\xf6C\x00R\xfa\"\x8c\xfd|\x13\x01\xc6~\b@J{\xff\x92\xd9|\x15&\x02\xe7\x7f\xcc\xcdP\x04\x9b/G\xae\xbb\xf7\xc1ߌ\x1aI\xa8\xaf\x8c4\xcd\xe5L!0s0\x14\xe1\xed>T\x00\xa9\x91\xd4t\x8b\xe0O18\xeb\x1b\x1a9\xffc\xb67\xdd\xd4\x06.?T\x00\xe9d\x7f\xc6~\xd9\xd0}6\xa9\xee\xe2\xa9\xf4o8cZS\x7f\xfa\x13\x02@\x05\x90B\xb9Z\xd3C.a6d\xe1*dF;JFFh\xfc!\x00\xc3S5\x9b\xaf\xc2D \xcdc\xc4\xd8|!\x00\xa9d\xfef4\xcd\xd8/'\x11H\xcfK\x00\x9b/\x04 \x1d^\x86ⴟg\"\x10\x1c\xdf\xc2\xe6\xcbA\xbck\x02\xd2\xf8+\x8eA\xadŰ\xf9\xa2\x02H\rl\xbe\x8a\xad\x04\x061\x14\xc1\xe6\v\x01H\x85չ\xe8\x126_\xc5R;\xb7ԗ\b`\xf3\x85\x00\xa4\xb7_1\xfa\x8cK\xe6\x86\b\xf4\xe4%\x10&\xcb#'W\x19\xfb!\x00\xc3\xc3\xd8\xcf-z1\x141\xf5\xcd{\xf5\xbb\vd\x7f\x97\x93\xaa\x0f\x1f\x12\x9b/G\xe9\x06\xda\xfa\xe6\xecA\x86\"\xed\xe9?\xff\x99\xc6\x1f\x15\xc0\xf0`\xf3\xe5(ar\xe01\xe2Zc\x89ҟ\n \xa5\xec\xcf\xd8\xcfi\xf6[\x8ba\xf3E\x05\x90\xde\xcd\x15\xd2\xf8s>\x8b\xec1\x141a\x82͗G\xd4\\\xfep\xabs\xd1%\x19E\x92\xda\\*\xc7E\xe0XG#\xff\xf0X\x1b\x8b\xa7\xbe\x9c\xf9\xc3<\xd7\v\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00|\xe4\xff\x01\xf6P(\xf3)+S\x1f\x00\x00\x00\x00IEND\xaeB`\x82"), +} diff --git a/encryption/cert.go b/encryption/cert.go new file mode 100644 index 000000000..3f6d5c679 --- /dev/null +++ b/encryption/cert.go @@ -0,0 +1,19 @@ +package encryption + +import "crypto/tls" + +func LoadTLSConfig(certFile, keyFile string) (*tls.Config, error) { + serverCert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + config := &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.NoClientCert, + NextProtos: []string{ + "h2", "http/1.1", // enable HTTP/2 + }, + } + return config, nil +} diff --git a/encryption/letsencrypt.go b/encryption/letsencrypt.go index cfe54ec5a..27a5e3110 100644 --- a/encryption/letsencrypt.go +++ b/encryption/letsencrypt.go @@ -9,7 +9,7 @@ import ( ) // CreateCertManager wraps common logic of generating Let's encrypt certificate. -func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Manager, error) { +func CreateCertManager(datadir string, letsencryptDomain ...string) (*autocert.Manager, error) { certDir := filepath.Join(datadir, "letsencrypt") if _, err := os.Stat(certDir); os.IsNotExist(err) { @@ -24,7 +24,7 @@ func CreateCertManager(datadir string, letsencryptDomain string) (*autocert.Mana certManager := &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(certDir), - HostPolicy: autocert.HostWhitelist(letsencryptDomain), + HostPolicy: autocert.HostWhitelist(letsencryptDomain...), } return certManager, nil diff --git a/encryption/route53.go b/encryption/route53.go new file mode 100644 index 000000000..3c81ab103 --- /dev/null +++ b/encryption/route53.go @@ -0,0 +1,87 @@ +package encryption + +import ( + "context" + "crypto/tls" + "fmt" + "os" + "strings" + + "github.com/caddyserver/certmagic" + "github.com/libdns/route53" + log "github.com/sirupsen/logrus" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/crypto/acme" +) + +// Route53TLS by default, loads the AWS configuration from the environment. +// env variables: AWS_REGION, AWS_PROFILE, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_SESSION_TOKEN +type Route53TLS struct { + DataDir string + Email string + Domains []string + CA string +} + +func (r *Route53TLS) GetCertificate() (*tls.Config, error) { + if len(r.Domains) == 0 { + return nil, fmt.Errorf("no domains provided") + } + + certmagic.Default.Logger = logger() + certmagic.Default.Storage = &certmagic.FileStorage{Path: r.DataDir} + certmagic.DefaultACME.Agreed = true + if r.Email != "" { + certmagic.DefaultACME.Email = r.Email + } else { + certmagic.DefaultACME.Email = emailFromDomain(r.Domains[0]) + } + + if r.CA == "" { + certmagic.DefaultACME.CA = certmagic.LetsEncryptProductionCA + } else { + certmagic.DefaultACME.CA = r.CA + } + + certmagic.DefaultACME.DNS01Solver = &certmagic.DNS01Solver{ + DNSManager: certmagic.DNSManager{ + DNSProvider: &route53.Provider{}, + }, + } + cm := certmagic.NewDefault() + if err := cm.ManageSync(context.Background(), r.Domains); err != nil { + log.Errorf("failed to manage certificate: %v", err) + return nil, err + } + + tlsConfig := &tls.Config{ + GetCertificate: cm.GetCertificate, + NextProtos: []string{"h2", "http/1.1", acme.ALPNProto}, + } + + return tlsConfig, nil +} + +func emailFromDomain(domain string) string { + if domain == "" { + return "" + } + + parts := strings.Split(domain, ".") + if len(parts) < 2 { + return "" + } + if parts[0] == "" { + return "" + } + return fmt.Sprintf("admin@%s.%s", parts[len(parts)-2], parts[len(parts)-1]) +} + +func logger() *zap.Logger { + return zap.New(zapcore.NewCore( + zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig()), + os.Stderr, + zap.ErrorLevel, + )) +} diff --git a/encryption/route53_test.go b/encryption/route53_test.go new file mode 100644 index 000000000..765b60f84 --- /dev/null +++ b/encryption/route53_test.go @@ -0,0 +1,84 @@ +package encryption + +import ( + "context" + "io" + "net/http" + "os" + "testing" + "time" +) + +func TestRoute53TLSConfig(t *testing.T) { + t.SkipNow() // This test requires AWS credentials + exampleString := "Hello, world!" + rtls := &Route53TLS{ + DataDir: t.TempDir(), + Email: os.Getenv("LE_EMAIL_ROUTE53"), + Domains: []string{os.Getenv("DOMAIN")}, + } + tlsConfig, err := rtls.GetCertificate() + if err != nil { + t.Errorf("Route53TLSConfig failed: %v", err) + } + + server := &http.Server{ + Addr: ":8443", + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(exampleString)) + }), + TLSConfig: tlsConfig, + } + + go func() { + err := server.ListenAndServeTLS("", "") + if err != http.ErrServerClosed { + t.Errorf("Failed to start server: %v", err) + } + }() + defer func() { + if err := server.Shutdown(context.Background()); err != nil { + t.Errorf("Failed to shutdown server: %v", err) + } + }() + + time.Sleep(1 * time.Second) + resp, err := http.Get("https://relay.godevltd.com:8443") + if err != nil { + t.Errorf("Failed to get response: %v", err) + return + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Errorf("Failed to read response body: %v", err) + } + if string(body) != exampleString { + t.Errorf("Unexpected response: %s", body) + } +} + +func Test_emailFromDomain(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"example.com", "admin@example.com"}, + {"x.example.com", "admin@example.com"}, + {"x.x.example.com", "admin@example.com"}, + {"*.example.com", "admin@example.com"}, + {"example", ""}, + {"", ""}, + {".com", ""}, + } + for _, tt := range tests { + t.Run("domain test", func(t *testing.T) { + if got := emailFromDomain(tt.input); got != tt.want { + t.Errorf("emailFromDomain() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go.mod b/go.mod index e80a7eb46..d9f6162e5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/netbirdio/netbird -go 1.21.0 +go 1.23.0 require ( cunicu.li/go-rosenpass v0.4.0 @@ -10,9 +10,9 @@ require ( github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.0 - github.com/kardianos/service v1.2.1-0.20210728001519-a323c3813bc7 + github.com/kardianos/service v1.2.3-0.20240613133416-becf2eb62b83 github.com/onsi/ginkgo v1.16.5 - github.com/onsi/gomega v1.23.0 + github.com/onsi/gomega v1.27.6 github.com/pion/ice/v3 v3.0.2 github.com/rs/cors v1.8.0 github.com/sirupsen/logrus v1.9.3 @@ -34,6 +34,7 @@ require ( fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/c-robinson/iplib v1.0.3 + github.com/caddyserver/certmagic v0.21.3 github.com/cilium/ebpf v0.15.0 github.com/coreos/go-iptables v0.7.0 github.com/creack/pty v1.1.18 @@ -50,18 +51,21 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/magiconair/properties v1.8.7 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/socket v0.4.1 - github.com/miekg/dns v1.1.43 + github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e + github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd + github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pion/logging v0.2.2 + github.com/pion/randutil v0.1.0 github.com/pion/stun/v2 v2.0.0 github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 @@ -70,6 +74,7 @@ require ( github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 @@ -81,6 +86,7 @@ require ( go.opentelemetry.io/otel/exporters/prometheus v0.48.0 go.opentelemetry.io/otel/metric v1.26.0 go.opentelemetry.io/otel/sdk/metric v1.26.0 + go.uber.org/zap v1.27.0 goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a @@ -93,6 +99,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.3 gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + nhooyr.io/websocket v1.8.11 ) require ( @@ -106,8 +113,23 @@ require ( github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/aws/aws-sdk-go-v2 v1.30.3 // indirect + github.com/aws/aws-sdk-go-v2/config v1.27.27 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.27 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 // indirect + github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 // indirect + github.com/aws/smithy-go v1.20.3 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect + github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/containerd v1.7.16 // indirect github.com/containerd/log v0.1.0 // indirect @@ -116,7 +138,7 @@ require ( github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect - github.com/docker/docker v26.1.4+incompatible // indirect + github.com/docker/docker v26.1.5+incompatible // indirect github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect @@ -140,7 +162,7 @@ require ( github.com/googleapis/gax-go/v2 v2.12.3 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/hashicorp/go-uuid v1.0.2 // indirect + github.com/hashicorp/go-uuid v1.0.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect @@ -149,13 +171,17 @@ require ( github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/native v1.1.0 // indirect github.com/jsummers/gobmp v0.0.0-20151104160322-e2ba15ffa76e // indirect github.com/kelseyhightower/envconfig v1.4.0 // indirect github.com/klauspost/compress v1.17.8 // indirect + github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/libdns/libdns v0.2.2 // indirect github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mholt/acmez/v2 v2.0.1 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/sequential v0.5.0 // indirect @@ -164,12 +190,12 @@ require ( github.com/morikuni/aec v1.0.0 // indirect github.com/nicksnyder/go-i18n/v2 v2.4.0 // indirect github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pegasus-kv/thrift v0.13.0 // indirect github.com/pion/dtls/v2 v2.2.10 // indirect github.com/pion/mdns v0.0.12 // indirect - github.com/pion/randutil v0.1.0 // indirect github.com/pion/transport/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -186,10 +212,12 @@ require ( github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect github.com/yuin/goldmark v1.7.1 // indirect + github.com/zeebo/blake3 v0.2.3 // indirect go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.26.0 // indirect go.opentelemetry.io/otel/trace v1.26.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/text v0.16.0 // indirect @@ -205,7 +233,7 @@ require ( k8s.io/apimachinery v0.26.2 // indirect ) -replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 +replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 diff --git a/go.sum b/go.sum index 6d0408013..d823f505a 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,34 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= +github.com/aws/aws-sdk-go-v2 v1.30.3 h1:jUeBtG0Ih+ZIFH0F4UkmL9w3cSpaMv9tYYDbzILP8dY= +github.com/aws/aws-sdk-go-v2 v1.30.3/go.mod h1:nIQjQVp5sfpQcTc9mPSr1B0PaWK5ByX9MOoDadSN4lc= +github.com/aws/aws-sdk-go-v2/config v1.27.27 h1:HdqgGt1OAP0HkEDDShEl0oSYa9ZZBSOmKpdpsDMdO90= +github.com/aws/aws-sdk-go-v2/config v1.27.27/go.mod h1:MVYamCg76dFNINkZFu4n4RjDixhVr51HLj4ErWzrVwg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27 h1:2raNba6gr2IfA0eqqiP2XiQ0UVOpGPgDSi0I9iAP+UI= +github.com/aws/aws-sdk-go-v2/credentials v1.17.27/go.mod h1:gniiwbGahQByxan6YjQUMcW4Aov6bLC3m+evgcoN4r4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11 h1:KreluoV8FZDEtI6Co2xuNk/UqI9iwMrOx/87PBNIKqw= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.11/go.mod h1:SeSUYBLsMYFoRvHE0Tjvn7kbxaUhl75CJi1sbfhMxkU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15 h1:SoNJ4RlFEQEbtDcCEt+QG56MY4fm4W8rYirAmq+/DdU= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.15/go.mod h1:U9ke74k1n2bf+RIgoX1SXFed1HLs51OgUSs+Ph0KJP8= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15 h1:C6WHdGnTDIYETAm5iErQUiVNsclNx9qbJVPIt03B6bI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.15/go.mod h1:ZQLZqhcu+JhSrA9/NXRm8SkDvsycE+JkV3WGY41e+IM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3 h1:dT3MqvGhSoaIhRseqw2I0yH81l7wiR2vjs57O51EAm8= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.3/go.mod h1:GlAeCkHwugxdHaueRr4nhPuY+WW+gR8UjlcqzPr1SPI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17 h1:HGErhhrxZlQ044RiM+WdoZxp0p+EGM62y3L6pwA4olE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.17/go.mod h1:RkZEx4l0EHYDJpWppMJ3nD9wZJAa8/0lq9aVC+r2UII= +github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3 h1:MmLCRqP4U4Cw9gJ4bNrCG0mWqEtBlmAVleyelcHARMU= +github.com/aws/aws-sdk-go-v2/service/route53 v1.42.3/go.mod h1:AMPjK2YnRh0YgOID3PqhJA1BRNfXDfGOnSsKHtAe8yA= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4 h1:BXx0ZIxvrJdSgSvKTZ+yRBeSqqgPM89VPlulEcl37tM= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.4/go.mod h1:ooyCOXjvJEsUw7x+ZDHeISPMhtwI3ZCB7ggFMcFfWLU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4 h1:yiwVzJW2ZxZTurVbYWA7QOrAaCYQR72t0wrSBfoesUE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.4/go.mod h1:0oxfLkpz3rQ/CHlx5hB7H69YUpFiI1tql6Q6Ne+1bCw= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3 h1:ZsDKRLXGWHk8WdtyYMoGNO7bTudrvuKpDKgMVRlepGE= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.3/go.mod h1:zwySh8fpFyXp9yOr/KVzxOl8SRqgf/IDw5aUt9UKFcQ= +github.com/aws/smithy-go v1.20.3 h1:ryHwveWzPV5BIof6fyDvor6V3iUL7nTfiTKXHiW05nE= +github.com/aws/smithy-go v1.20.3/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= @@ -87,6 +115,10 @@ github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d h1:pVrfxiGfwel github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d/go.mod h1:H0wQNHz2YrLsuXOZozoeDmnHXkNCRmMW0gwFWDfEZDA= github.com/c-robinson/iplib v1.0.3 h1:NG0UF0GoEsrC1/vyfX1Lx2Ss7CySWl3KqqXh3q4DdPU= github.com/c-robinson/iplib v1.0.3/go.mod h1:i3LuuFL1hRT5gFpBRnEydzw8R6yhGkF4szNDIbF8pgo= +github.com/caddyserver/certmagic v0.21.3 h1:pqRRry3yuB4CWBVq9+cUqu+Y6E2z8TswbhNx1AZeYm0= +github.com/caddyserver/certmagic v0.21.3/go.mod h1:Zq6pklO9nVRl3DIFUw9gVUfXKdpc/0qwTUAQMBlfgtI= +github.com/caddyserver/zerossl v0.1.3 h1:onS+pxp3M8HnHpN5MMbOMyNjmTheJyWRaZYwn+YTAyA= +github.com/caddyserver/zerossl v0.1.3/go.mod h1:CxA0acn7oEGO6//4rtrRjYgEoa4MFw/XofZnrYwGqG4= github.com/cenkalti/backoff/v4 v4.1.0/go.mod h1:scbssz8iZGpm3xbr14ovlUdkxfGXNInqkPWOWmG2CLw= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= @@ -132,8 +164,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= -github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU= -github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker v26.1.5+incompatible h1:NEAxTwEjxV6VbBMBoGG3zPqbiJosIApZjxlbrG9q3/g= +github.com/docker/docker v26.1.5+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= @@ -207,6 +239,8 @@ github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZs github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-text/render v0.1.0 h1:osrmVDZNHuP1RSu3pNG7Z77Sd2xSbcb/xWytAj9kyVs= github.com/go-text/render v0.1.0/go.mod h1:jqEuNMenrmj6QRnkdpeaP0oKGFLDNhDkVKwGjsWWYU4= github.com/go-text/typesetting v0.1.0 h1:vioSaLPYcHwPEPLT7gsjCGDCoYSbljxoHJzMnKwVvHw= @@ -350,8 +384,9 @@ github.com/hashicorp/go-sockaddr v1.0.0/go.mod h1:7Xibr9yA9JjQq1JpNB2Vw7kxv8xerX github.com/hashicorp/go-syslog v1.0.0/go.mod h1:qPfqrKkXGihmCqbJM2mZgkZGvKG1dFdvsLplgctolz4= github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= -github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE= github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90= @@ -382,6 +417,10 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/json-iterator/go v0.0.0-20180612202835-f2b4162afba3/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= @@ -401,6 +440,9 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -413,6 +455,10 @@ github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.1.0/go.mod h1:+cyI34gQWZcE1eQU7NVgKkkzdXDQHr1dBMtdAPozLkw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/libdns/libdns v0.2.2 h1:O6ws7bAfRPaBsgAYt8MDe2HcNBGC29hkZ9MX2eUSX3s= +github.com/libdns/libdns v0.2.2/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= +github.com/libdns/route53 v1.5.0 h1:2SKdpPFl/qgWsXQvsLNJJAoX7rSxlk7zgoL4jnWdXVA= +github.com/libdns/route53 v1.5.0/go.mod h1:joT4hKmaTNKHEwb7GmZ65eoDz1whTu7KKYPS8ZqIh6Q= github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tAFlj1FYZl8ztUZ13bdq+PLY+NOfbyI= github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= @@ -431,9 +477,11 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= +github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= +github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/miekg/dns v1.1.43 h1:JKfpVSCB84vrAmHzyrsxB5NAr5kLoMXZArPSw7Qlgyg= -github.com/miekg/dns v1.1.43/go.mod h1:+evo5L0630/F6ca/Z9+GAqzhjGyn8/c+TBaOyfEl0V4= +github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= +github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= @@ -473,10 +521,12 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e h1:LYxhAmiEzSldLELHSMVoUnRPq3ztTNQImrD27frrGsI= -github.com/netbirdio/management-integrations/integrations v0.0.0-20240703085513-32605f7ffd8e/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= -github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0 h1:hirFRfx3grVA/9eEyjME5/z3nxdJlN9kfQpvWWPk32g= -github.com/netbirdio/service v0.0.0-20230215170314-b923b89432b0/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd h1:phKq1S1Y/lnqEhP5Qknta733+rPX16dRDHM7hKkot9c= +github.com/netbirdio/management-integrations/integrations v0.0.0-20240929132811-9af486d346fd/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= +github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f h1:Rl23OSc2xKFyxiuBXtWDMzhZBV4gOM7lhFxvYoCmBZg= +github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241002125159-0e132af8c51f/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed h1:t0UADZUJDaaZgfKrt8JUPrOLL9Mg/ryjP85RAH53qgs= github.com/netbirdio/wireguard-go v0.0.0-20240105182236-6c340dd55aed/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= @@ -492,14 +542,14 @@ github.com/onsi/ginkgo v1.10.1/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.4.0 h1:+Ig9nvqgS5OBSACXNk15PLdp0U9XPYROt9CFzVdFGIs= -github.com/onsi/ginkgo/v2 v2.4.0/go.mod h1:iHkDK1fKGcBoEHT5W7YBq4RFWaQulw+caOMkAt4OrFo= +github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= +github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= -github.com/onsi/gomega v1.23.0 h1:/oxKu9c2HVap+F3PfKort2Hw5DEU+HGlW8n+tguWsys= -github.com/onsi/gomega v1.23.0/go.mod h1:Z/NWtiqwBrwUt4/2loMmHL63EDLnYHmVbuBpDr2vQAg= +github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= @@ -592,6 +642,8 @@ github.com/smartystreets/assertions v1.13.0/go.mod h1:wDmR7qL282YbGsPy6H/yAsesrx github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/smartystreets/goconvey v1.7.2 h1:9RBaZCeXEQ3UselpuwUQHltGVXvdwm6cv1hgR6gDIPg= github.com/smartystreets/goconvey v1.7.2/go.mod h1:Vw0tHAZW6lzCRk3xgdin6fKYcG+G3Pg9vgXWeJpQFMM= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spf13/afero v1.6.0/go.mod h1:Ai8FlHk4v/PARR026UzYexafAt9roJ7LcLMAmO6Z93I= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/spf13/cast v1.5.0 h1:rj3WzYc11XZaIZMPKmwP96zkFEnnAmV8s6XbB2aY32w= @@ -660,6 +712,12 @@ github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc= github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= +github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= +github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= +github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ= +github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo= +github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= go.etcd.io/etcd/api/v3 v3.5.0/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs= go.etcd.io/etcd/client/pkg/v3 v3.5.0/go.mod h1:IJHfcCEKxYu1Os13ZdwCwIUTUVGYTSAM3YSwc9/Ac1g= go.etcd.io/etcd/client/v2 v2.305.0/go.mod h1:h9puh54ZTgAKtEbut2oe9P4L/oqKCVB6xsXlzd7alYQ= @@ -695,8 +753,14 @@ go.opentelemetry.io/otel/trace v1.26.0/go.mod h1:4iDxvGDQuUkHve82hJJ8UqrwswHYsZu go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= goauthentik.io/api/v3 v3.2023051.3 h1:NebAhD/TeTWNo/9X3/Uj+rM5fG1HaiLOlKTNLQv9Qq4= goauthentik.io/api/v3 v3.2023051.3/go.mod h1:nYECml4jGbp/541hj8GcylKQG1gVBsKppHy4+7G8u4U= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -890,7 +954,6 @@ golang.org/x/sys v0.0.0-20210104204734-6f8348627aad/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210220050731-9a76102bfb43/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210303074136-134d130e1a04/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210305230114-8fe3ee5dd75b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210315160823-c6e025ad8005/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210320140829-1e4c9ba3b0c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1187,6 +1250,8 @@ k8s.io/gengo v0.0.0-20190128074634-0689ccc1d7d6/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8 k8s.io/klog v0.0.0-20181102134211-b9b56d5dfc92/go.mod h1:Gq+BEi5rUBO/HRz0bTSXDUcqjScdoY3a9IHpCEIOOfk= k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= k8s.io/kube-openapi v0.0.0-20191107075043-30be4d16710a/go.mod h1:1TqjTSzOxsLGIKfj0lK8EeCP7K1iUG65v09OM0/WG5E= +nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0= +nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/iface/iface_darwin.go b/iface/iface_darwin.go deleted file mode 100644 index 15e4a7817..000000000 --- a/iface/iface_darwin.go +++ /dev/null @@ -1,38 +0,0 @@ -//go:build !ios - -package iface - -import ( - "fmt" - - "github.com/pion/transport/v3" - - "github.com/netbirdio/netbird/iface/bind" - "github.com/netbirdio/netbird/iface/netstack" -) - -// NewWGIFace Creates a new WireGuard interface instance -func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) { - wgAddress, err := parseWGAddress(address) - if err != nil { - return nil, err - } - - wgIFace := &WGIface{ - userspaceBind: true, - } - - if netstack.IsEnabled() { - wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn) - return wgIFace, nil - } - - wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn) - - return wgIFace, nil -} - -// CreateOnAndroid this function make sense on mobile only -func (w *WGIface) CreateOnAndroid([]string, string, []string) error { - return fmt.Errorf("this function has not implemented on this platform") -} diff --git a/iface/tun.go b/iface/tun.go deleted file mode 100644 index b3c0f9d80..000000000 --- a/iface/tun.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build !android -// +build !android - -package iface - -import ( - "github.com/netbirdio/netbird/iface/bind" -) - -type wgTunDevice interface { - Create() (wgConfigurer, error) - Up() (*bind.UniversalUDPMuxDefault, error) - UpdateAddr(address WGAddress) error - WgAddress() WGAddress - DeviceName() string - Close() error - Wrapper() *DeviceWrapper // todo eliminate this function -} diff --git a/iface/wg_configurer.go b/iface/wg_configurer.go deleted file mode 100644 index dd38ba075..000000000 --- a/iface/wg_configurer.go +++ /dev/null @@ -1,21 +0,0 @@ -package iface - -import ( - "errors" - "net" - "time" - - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -var ErrPeerNotFound = errors.New("peer not found") - -type wgConfigurer interface { - configureInterface(privateKey string, port int) error - updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error - removePeer(peerKey string) error - addAllowedIP(peerKey string, allowedIP string) error - removeAllowedIP(peerKey string, allowedIP string) error - close() - getStats(peerKey string) (WGStats, error) -} diff --git a/infrastructure_files/base.setup.env b/infrastructure_files/base.setup.env index 296e165f0..45dce8d88 100644 --- a/infrastructure_files/base.setup.env +++ b/infrastructure_files/base.setup.env @@ -20,6 +20,12 @@ NETBIRD_MGMT_IDP_SIGNKEY_REFRESH=${NETBIRD_MGMT_IDP_SIGNKEY_REFRESH:-false} NETBIRD_SIGNAL_PROTOCOL="http" NETBIRD_SIGNAL_PORT=${NETBIRD_SIGNAL_PORT:-10000} +# Relay +NETBIRD_RELAY_DOMAIN=${NETBIRD_RELAY_DOMAIN:-$NETBIRD_DOMAIN} +NETBIRD_RELAY_PORT=${NETBIRD_RELAY_PORT:-33080} +# Relay auth secret +NETBIRD_RELAY_AUTH_SECRET= + # Turn TURN_DOMAIN=${NETBIRD_TURN_DOMAIN:-$NETBIRD_DOMAIN} @@ -69,7 +75,7 @@ NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-"latest"} NETBIRD_SIGNAL_TAG=${NETBIRD_SIGNAL_TAG:-"latest"} NETBIRD_MANAGEMENT_TAG=${NETBIRD_MANAGEMENT_TAG:-"latest"} COTURN_TAG=${COTURN_TAG:-"latest"} - +NETBIRD_RELAY_TAG=${NETBIRD_RELAY_TAG:-"latest"} # exports export NETBIRD_DOMAIN @@ -123,3 +129,7 @@ export NETBIRD_SIGNAL_TAG export NETBIRD_MANAGEMENT_TAG export COTURN_TAG export NETBIRD_TURN_EXTERNAL_IP +export NETBIRD_RELAY_DOMAIN +export NETBIRD_RELAY_PORT +export NETBIRD_RELAY_AUTH_SECRET +export NETBIRD_RELAY_TAG diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index f04735de6..ff33004b2 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -41,6 +41,18 @@ if [[ "x-$NETBIRD_DOMAIN" == "x-" ]]; then exit 1 fi +# Check if PostgreSQL is set as the store engine +if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then + # Exit if 'NETBIRD_STORE_ENGINE_POSTGRES_DSN' is not set + if [[ -z "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" ]]; then + echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=postgres but NETBIRD_STORE_ENGINE_POSTGRES_DSN is not set." + echo "Please add the following line to your setup.env file:" + echo 'NETBIRD_STORE_ENGINE_POSTGRES_DSN="host= user= password= dbname= port="' + exit 1 + fi + export NETBIRD_STORE_ENGINE_POSTGRES_DSN +fi + # local development or tests if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted" @@ -77,6 +89,11 @@ fi export TURN_EXTERNAL_IP_CONFIG +# if not provided, we generate a relay auth secret +if [[ "x-$NETBIRD_RELAY_AUTH_SECRET" == "x-" ]]; then + export NETBIRD_RELAY_AUTH_SECRET=$(openssl rand -base64 32 | sed 's/=//g') +fi + artifacts_path="./artifacts" mkdir -p $artifacts_path diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 6b6831493..ba68b3f8d 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -49,6 +49,23 @@ services: options: max-size: "500m" max-file: "2" + # Relay + relay: + image: netbirdio/relay:$NETBIRD_RELAY_TAG + restart: unless-stopped + environment: + - NB_LOG_LEVEL=info + - NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT + - NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT + # todo: change to a secure secret + - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET + ports: + - $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" # Management management: @@ -77,6 +94,9 @@ services: options: max-size: "500m" max-file: "2" + environment: + - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN + # Coturn coturn: image: coturn/coturn:$COTURN_TAG diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index d3ae6529a..c4415d848 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -81,7 +81,9 @@ services: - traefik.http.routers.netbird-management.service=netbird-management - traefik.http.services.netbird-management.loadbalancer.server.port=443 - traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c - + environment: + - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN + # Coturn coturn: image: coturn/coturn:$COTURN_TAG diff --git a/infrastructure_files/download-geolite2.sh b/infrastructure_files/download-geolite2.sh deleted file mode 100755 index 4a9db5e01..000000000 --- a/infrastructure_files/download-geolite2.sh +++ /dev/null @@ -1,109 +0,0 @@ -#!/bin/bash - -# to install sha256sum on mac: brew install coreutils -if ! command -v sha256sum &> /dev/null -then - echo "sha256sum is not installed or not in PATH, please install with your package manager. e.g. sudo apt install sha256sum" > /dev/stderr - exit 1 -fi - -if ! command -v sqlite3 &> /dev/null -then - echo "sqlite3 is not installed or not in PATH, please install with your package manager. e.g. sudo apt install sqlite3" > /dev/stderr - exit 1 -fi - -if ! command -v unzip &> /dev/null -then - echo "unzip is not installed or not in PATH, please install with your package manager. e.g. sudo apt install unzip" > /dev/stderr - exit 1 -fi - -download_geolite_mmdb() { - DATABASE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz" - SIGNATURE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256" - # Download the database and signature files - echo "Downloading mmdb signature file..." - SIGNATURE_FILE=$(curl -s -L -O -J "$SIGNATURE_URL" -w "%{filename_effective}") - echo "Downloading mmdb database file..." - DATABASE_FILE=$(curl -s -L -O -J "$DATABASE_URL" -w "%{filename_effective}") - - # Verify the signature - echo "Verifying signature..." - if sha256sum -c --status "$SIGNATURE_FILE"; then - echo "Signature is valid." - else - echo "Signature is invalid. Aborting." - exit 1 - fi - - # Unpack the database file - EXTRACTION_DIR=$(basename "$DATABASE_FILE" .tar.gz) - echo "Unpacking $DATABASE_FILE..." - mkdir -p "$EXTRACTION_DIR" - tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1 - - MMDB_FILE="GeoLite2-City.mmdb" - cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE - - # Remove downloaded files - rm -r "$EXTRACTION_DIR" - rm "$DATABASE_FILE" "$SIGNATURE_FILE" - - # Done. Print next steps - echo "" - echo "Process completed successfully." - echo "Now you can place $MMDB_FILE to 'datadir' of management service." - echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/" -} - - -download_geolite_csv_and_create_sqlite_db() { - DATABASE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip" - SIGNATURE_URL="https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256" - - - # Download the database file - echo "Downloading csv signature file..." - SIGNATURE_FILE=$(curl -s -L -O -J "$SIGNATURE_URL" -w "%{filename_effective}") - echo "Downloading csv database file..." - DATABASE_FILE=$(curl -s -L -O -J "$DATABASE_URL" -w "%{filename_effective}") - - # Verify the signature - echo "Verifying signature..." - if sha256sum -c --status "$SIGNATURE_FILE"; then - echo "Signature is valid." - else - echo "Signature is invalid. Aborting." - exit 1 - fi - - # Unpack the database file - EXTRACTION_DIR=$(basename "$DATABASE_FILE" .zip) - DB_NAME="geonames.db" - - echo "Unpacking $DATABASE_FILE..." - unzip "$DATABASE_FILE" > /dev/null 2>&1 - -# Create SQLite database and import data from CSV -sqlite3 "$DB_NAME" < dashboard.env echo "" > turnserver.conf echo "" > management.json + echo "" > relay.env mkdir -p machinekey chmod 777 machinekey @@ -498,6 +514,7 @@ initEnvironment() { renderTurnServerConf > turnserver.conf renderManagementJson > management.json renderDashboardEnv > dashboard.env + renderRelayEnv > relay.env echo -e "\nStarting NetBird services\n" $DOCKER_COMPOSE_COMMAND up -d @@ -541,7 +558,7 @@ renderCaddyfile() { # clickjacking protection # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-frame-options - X-Frame-Options "DENY" + X-Frame-Options "SAMEORIGIN" # xss protection # https://cheatsheetseries.owasp.org/cheatsheets/HTTP_Headers_Cheat_Sheet.html#x-xss-protection @@ -559,6 +576,8 @@ renderCaddyfile() { :80${CADDY_SECURE_DOMAIN} { import security_headers + # relay + reverse_proxy /relay* relay:80 # Signal reverse_proxy /signalexchange.SignalExchange/* h2c://signal:10000 # Management @@ -629,6 +648,11 @@ renderManagementJson() { ], "TimeBasedCredentials": false }, + "Relay": { + "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], + "CredentialsTTL": "24h", + "Secret": "$NETBIRD_RELAY_AUTH_SECRET" + }, "Signal": { "Proto": "$NETBIRD_HTTP_PROTOCOL", "URI": "$NETBIRD_DOMAIN:$NETBIRD_PORT" @@ -744,6 +768,15 @@ POSTGRES_PASSWORD=$POSTGRES_ROOT_PASSWORD EOF } +renderRelayEnv() { + cat < management.PeerSystemMeta 17, // 1: management.SyncResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig - 20, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig - 22, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig - 21, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap - 37, // 5: management.SyncResponse.Checks:type_name -> management.Checks + 21, // 2: management.SyncResponse.peerConfig:type_name -> management.PeerConfig + 23, // 3: management.SyncResponse.remotePeers:type_name -> management.RemotePeerConfig + 22, // 4: management.SyncResponse.NetworkMap:type_name -> management.NetworkMap + 38, // 5: management.SyncResponse.Checks:type_name -> management.Checks 13, // 6: management.SyncMetaRequest.meta:type_name -> management.PeerSystemMeta 13, // 7: management.LoginRequest.meta:type_name -> management.PeerSystemMeta 10, // 8: management.LoginRequest.peerKeys:type_name -> management.PeerKeys - 36, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress + 37, // 9: management.PeerSystemMeta.networkAddresses:type_name -> management.NetworkAddress 11, // 10: management.PeerSystemMeta.environment:type_name -> management.Environment 12, // 11: management.PeerSystemMeta.files:type_name -> management.File 17, // 12: management.LoginResponse.wiretrusteeConfig:type_name -> management.WiretrusteeConfig - 20, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig - 37, // 14: management.LoginResponse.Checks:type_name -> management.Checks - 38, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp + 21, // 13: management.LoginResponse.peerConfig:type_name -> management.PeerConfig + 38, // 14: management.LoginResponse.Checks:type_name -> management.Checks + 42, // 15: management.ServerKeyResponse.expiresAt:type_name -> google.protobuf.Timestamp 18, // 16: management.WiretrusteeConfig.stuns:type_name -> management.HostConfig - 19, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig + 20, // 17: management.WiretrusteeConfig.turns:type_name -> management.ProtectedHostConfig 18, // 18: management.WiretrusteeConfig.signal:type_name -> management.HostConfig - 0, // 19: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol - 18, // 20: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig - 23, // 21: management.PeerConfig.sshConfig:type_name -> management.SSHConfig - 20, // 22: management.NetworkMap.peerConfig:type_name -> management.PeerConfig - 22, // 23: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig - 29, // 24: management.NetworkMap.Routes:type_name -> management.Route - 30, // 25: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig - 22, // 26: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig - 35, // 27: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule - 23, // 28: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig - 1, // 29: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider - 28, // 30: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 28, // 31: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig - 33, // 32: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup - 31, // 33: management.DNSConfig.CustomZones:type_name -> management.CustomZone - 32, // 34: management.CustomZone.Records:type_name -> management.SimpleRecord - 34, // 35: management.NameServerGroup.NameServers:type_name -> management.NameServer - 2, // 36: management.FirewallRule.Direction:type_name -> management.FirewallRule.direction - 3, // 37: management.FirewallRule.Action:type_name -> management.FirewallRule.action - 4, // 38: management.FirewallRule.Protocol:type_name -> management.FirewallRule.protocol - 5, // 39: management.ManagementService.Login:input_type -> management.EncryptedMessage - 5, // 40: management.ManagementService.Sync:input_type -> management.EncryptedMessage - 16, // 41: management.ManagementService.GetServerKey:input_type -> management.Empty - 16, // 42: management.ManagementService.isHealthy:input_type -> management.Empty - 5, // 43: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 44: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage - 5, // 45: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage - 5, // 46: management.ManagementService.Login:output_type -> management.EncryptedMessage - 5, // 47: management.ManagementService.Sync:output_type -> management.EncryptedMessage - 15, // 48: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse - 16, // 49: management.ManagementService.isHealthy:output_type -> management.Empty - 5, // 50: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage - 5, // 51: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage - 16, // 52: management.ManagementService.SyncMeta:output_type -> management.Empty - 46, // [46:53] is the sub-list for method output_type - 39, // [39:46] is the sub-list for method input_type - 39, // [39:39] is the sub-list for extension type_name - 39, // [39:39] is the sub-list for extension extendee - 0, // [0:39] is the sub-list for field type_name + 19, // 19: management.WiretrusteeConfig.relay:type_name -> management.RelayConfig + 3, // 20: management.HostConfig.protocol:type_name -> management.HostConfig.Protocol + 18, // 21: management.ProtectedHostConfig.hostConfig:type_name -> management.HostConfig + 24, // 22: management.PeerConfig.sshConfig:type_name -> management.SSHConfig + 21, // 23: management.NetworkMap.peerConfig:type_name -> management.PeerConfig + 23, // 24: management.NetworkMap.remotePeers:type_name -> management.RemotePeerConfig + 30, // 25: management.NetworkMap.Routes:type_name -> management.Route + 31, // 26: management.NetworkMap.DNSConfig:type_name -> management.DNSConfig + 23, // 27: management.NetworkMap.offlinePeers:type_name -> management.RemotePeerConfig + 36, // 28: management.NetworkMap.FirewallRules:type_name -> management.FirewallRule + 40, // 29: management.NetworkMap.routesFirewallRules:type_name -> management.RouteFirewallRule + 24, // 30: management.RemotePeerConfig.sshConfig:type_name -> management.SSHConfig + 4, // 31: management.DeviceAuthorizationFlow.Provider:type_name -> management.DeviceAuthorizationFlow.provider + 29, // 32: management.DeviceAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 29, // 33: management.PKCEAuthorizationFlow.ProviderConfig:type_name -> management.ProviderConfig + 34, // 34: management.DNSConfig.NameServerGroups:type_name -> management.NameServerGroup + 32, // 35: management.DNSConfig.CustomZones:type_name -> management.CustomZone + 33, // 36: management.CustomZone.Records:type_name -> management.SimpleRecord + 35, // 37: management.NameServerGroup.NameServers:type_name -> management.NameServer + 1, // 38: management.FirewallRule.Direction:type_name -> management.RuleDirection + 2, // 39: management.FirewallRule.Action:type_name -> management.RuleAction + 0, // 40: management.FirewallRule.Protocol:type_name -> management.RuleProtocol + 41, // 41: management.PortInfo.range:type_name -> management.PortInfo.Range + 2, // 42: management.RouteFirewallRule.action:type_name -> management.RuleAction + 0, // 43: management.RouteFirewallRule.protocol:type_name -> management.RuleProtocol + 39, // 44: management.RouteFirewallRule.portInfo:type_name -> management.PortInfo + 5, // 45: management.ManagementService.Login:input_type -> management.EncryptedMessage + 5, // 46: management.ManagementService.Sync:input_type -> management.EncryptedMessage + 16, // 47: management.ManagementService.GetServerKey:input_type -> management.Empty + 16, // 48: management.ManagementService.isHealthy:input_type -> management.Empty + 5, // 49: management.ManagementService.GetDeviceAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 50: management.ManagementService.GetPKCEAuthorizationFlow:input_type -> management.EncryptedMessage + 5, // 51: management.ManagementService.SyncMeta:input_type -> management.EncryptedMessage + 5, // 52: management.ManagementService.Login:output_type -> management.EncryptedMessage + 5, // 53: management.ManagementService.Sync:output_type -> management.EncryptedMessage + 15, // 54: management.ManagementService.GetServerKey:output_type -> management.ServerKeyResponse + 16, // 55: management.ManagementService.isHealthy:output_type -> management.Empty + 5, // 56: management.ManagementService.GetDeviceAuthorizationFlow:output_type -> management.EncryptedMessage + 5, // 57: management.ManagementService.GetPKCEAuthorizationFlow:output_type -> management.EncryptedMessage + 16, // 58: management.ManagementService.SyncMeta:output_type -> management.Empty + 52, // [52:59] is the sub-list for method output_type + 45, // [45:52] is the sub-list for method input_type + 45, // [45:45] is the sub-list for extension type_name + 45, // [45:45] is the sub-list for extension extendee + 0, // [0:45] is the sub-list for field type_name } func init() { file_management_proto_init() } @@ -3256,7 +3630,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProtectedHostConfig); i { + switch v := v.(*RelayConfig); i { case 0: return &v.state case 1: @@ -3268,7 +3642,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PeerConfig); i { + switch v := v.(*ProtectedHostConfig); i { case 0: return &v.state case 1: @@ -3280,7 +3654,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkMap); i { + switch v := v.(*PeerConfig); i { case 0: return &v.state case 1: @@ -3292,7 +3666,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemotePeerConfig); i { + switch v := v.(*NetworkMap); i { case 0: return &v.state case 1: @@ -3304,7 +3678,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SSHConfig); i { + switch v := v.(*RemotePeerConfig); i { case 0: return &v.state case 1: @@ -3316,7 +3690,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlowRequest); i { + switch v := v.(*SSHConfig); i { case 0: return &v.state case 1: @@ -3328,7 +3702,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeviceAuthorizationFlow); i { + switch v := v.(*DeviceAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -3340,7 +3714,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlowRequest); i { + switch v := v.(*DeviceAuthorizationFlow); i { case 0: return &v.state case 1: @@ -3352,7 +3726,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PKCEAuthorizationFlow); i { + switch v := v.(*PKCEAuthorizationFlowRequest); i { case 0: return &v.state case 1: @@ -3364,7 +3738,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ProviderConfig); i { + switch v := v.(*PKCEAuthorizationFlow); i { case 0: return &v.state case 1: @@ -3376,7 +3750,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*ProviderConfig); i { case 0: return &v.state case 1: @@ -3388,7 +3762,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DNSConfig); i { + switch v := v.(*Route); i { case 0: return &v.state case 1: @@ -3400,7 +3774,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CustomZone); i { + switch v := v.(*DNSConfig); i { case 0: return &v.state case 1: @@ -3412,7 +3786,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SimpleRecord); i { + switch v := v.(*CustomZone); i { case 0: return &v.state case 1: @@ -3424,7 +3798,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[28].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServerGroup); i { + switch v := v.(*SimpleRecord); i { case 0: return &v.state case 1: @@ -3436,7 +3810,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[29].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NameServer); i { + switch v := v.(*NameServerGroup); i { case 0: return &v.state case 1: @@ -3448,7 +3822,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[30].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*FirewallRule); i { + switch v := v.(*NameServer); i { case 0: return &v.state case 1: @@ -3460,7 +3834,7 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NetworkAddress); i { + switch v := v.(*FirewallRule); i { case 0: return &v.state case 1: @@ -3472,6 +3846,18 @@ func file_management_proto_init() { } } file_management_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*NetworkAddress); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*Checks); i { case 0: return &v.state @@ -3483,6 +3869,46 @@ func file_management_proto_init() { return nil } } + file_management_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RouteFirewallRule); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_management_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PortInfo_Range); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + file_management_proto_msgTypes[34].OneofWrappers = []interface{}{ + (*PortInfo_Port)(nil), + (*PortInfo_Range_)(nil), } type x struct{} out := protoimpl.TypeBuilder{ @@ -3490,7 +3916,7 @@ func file_management_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_management_proto_rawDesc, NumEnums: 5, - NumMessages: 33, + NumMessages: 37, NumExtensions: 0, NumServices: 1, }, diff --git a/management/proto/management.proto b/management/proto/management.proto index 06b243773..fe6a828b1 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -177,6 +177,8 @@ message WiretrusteeConfig { // a Signal server config HostConfig signal = 3; + + RelayConfig relay = 4; } // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) @@ -193,6 +195,13 @@ message HostConfig { DTLS = 4; } } + +message RelayConfig { + repeated string urls = 1; + string tokenPayload = 2; + string tokenSignature = 3; +} + // ProtectedHostConfig is similar to HostConfig but has additional user and password // Mostly used for TURN servers message ProtectedHostConfig { @@ -245,6 +254,12 @@ message NetworkMap { // firewallRulesIsEmpty indicates whether FirewallRule array is empty or not to bypass protobuf null and empty array equality. bool firewallRulesIsEmpty = 9; + + // RoutesFirewallRules represents a list of routes firewall rules to be applied to peer + repeated RouteFirewallRule routesFirewallRules = 10; + + // RoutesFirewallRulesIsEmpty indicates whether RouteFirewallRule array is empty or not to bypass protobuf null and empty array equality. + bool routesFirewallRulesIsEmpty = 11; } // RemotePeerConfig represents a configuration of a remote peer. @@ -375,29 +390,32 @@ message NameServer { int64 Port = 3; } +enum RuleProtocol { + UNKNOWN = 0; + ALL = 1; + TCP = 2; + UDP = 3; + ICMP = 4; +} + +enum RuleDirection { + IN = 0; + OUT = 1; +} + +enum RuleAction { + ACCEPT = 0; + DROP = 1; +} + + // FirewallRule represents a firewall rule message FirewallRule { string PeerIP = 1; - direction Direction = 2; - action Action = 3; - protocol Protocol = 4; + RuleDirection Direction = 2; + RuleAction Action = 3; + RuleProtocol Protocol = 4; string Port = 5; - - enum direction { - IN = 0; - OUT = 1; - } - enum action { - ACCEPT = 0; - DROP = 1; - } - enum protocol { - UNKNOWN = 0; - ALL = 1; - TCP = 2; - UDP = 3; - ICMP = 4; - } } message NetworkAddress { @@ -406,5 +424,40 @@ message NetworkAddress { } message Checks { - repeated string Files= 1; + repeated string Files = 1; } + + +message PortInfo { + oneof portSelection { + uint32 port = 1; + Range range = 2; + } + + message Range { + uint32 start = 1; + uint32 end = 2; + } +} + +// RouteFirewallRule signifies a firewall rule applicable for a routed network. +message RouteFirewallRule { + // sourceRanges IP ranges of the routing peers. + repeated string sourceRanges = 1; + + // Action to be taken by the firewall when the rule is applicable. + RuleAction action = 2; + + // Network prefix for the routed network. + string destination = 3; + + // Protocol of the routed network. + RuleProtocol protocol = 4; + + // Details about the port. + PortInfo portInfo = 5; + + // IsDynamic indicates if the route is a DNS route. + bool isDynamic = 6; +} + diff --git a/management/server/account.go b/management/server/account.go index 0a91ae6d8..d5e8c8cf8 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -20,11 +20,6 @@ import ( cacheStore "github.com/eko/gocache/v3/store" "github.com/hashicorp/go-multierror" "github.com/miekg/dns" - gocache "github.com/patrickmn/go-cache" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -41,6 +36,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/route" + gocache "github.com/patrickmn/go-cache" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" ) const ( @@ -63,6 +62,7 @@ func cacheEntryExpiration() time.Duration { type AccountManager interface { GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + GetAccount(ctx context.Context, accountID string) (*Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) @@ -75,12 +75,14 @@ type AccountManager interface { SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) - GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) + GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error + GetUserByID(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -94,6 +96,7 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) + UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) @@ -106,11 +109,11 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) @@ -144,6 +147,7 @@ type AccountManager interface { SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) } type DefaultAccountManager struct { @@ -160,6 +164,8 @@ type DefaultAccountManager struct { eventStore activity.Store geo *geolocation.Geolocation + requestBuffer *AccountRequestBuffer + // singleAccountMode indicates whether the instance has a single account. // If true, then every new user will end up under the same account. // This value will be set to false if management service has more than one account. @@ -260,6 +266,16 @@ type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } @@ -444,6 +460,7 @@ func (a *Account) GetPeerNetworkMap( } routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) dnsUpdate := nbdns.Config{ @@ -467,12 +484,19 @@ func (a *Account) GetPeerNetworkMap( DNSConfig: dnsUpdate, OfflinePeers: expiredPeers, FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, } if metrics != nil { objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) metrics.CountNetworkMapObjects(objectCount) metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ + "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", + a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) + } } return nm @@ -691,14 +715,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string { return grps } -func (a *Account) getUserGroups(userID string) ([]string, error) { - user, err := a.FindUser(userID) - if err != nil { - return nil, err - } - return user.AutoGroups, nil -} - func (a *Account) getPeerDNSManagementStatus(peerID string) bool { peerGroups := a.getPeerGroups(peerID) enabled := true @@ -725,14 +741,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap { return groupList } -func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) { - key, err := a.FindSetupKey(setupKey) - if err != nil { - return nil, err - } - return key.AutoGroups, nil -} - func (a *Account) getTakenIPs() []net.IP { var takenIps []net.IP for _, existingPeer := range a.Peers { @@ -966,6 +974,7 @@ func BuildManager( userDeleteFromIDPEnabled: userDeleteFromIDPEnabled, integratedPeerValidator: integratedPeerValidator, metrics: metrics, + requestBuffer: NewAccountRequestBuffer(ctx, store), } allAccounts := store.GetAllAccounts(ctx) // enable single account mode only if configured by user and number of existing accounts is not grater than 1 @@ -1253,25 +1262,37 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and -// userID doesn't have an account associated with it, one account is created -// domain is used to create a new account if no account is found -func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { +// GetAccountIDByUserOrAccountID retrieves the account ID based on either the userID or accountID provided. +// If an accountID is provided, it checks if the account exists and returns it. +// If no accountID is provided, but a userID is given, it tries to retrieve the account by userID. +// If the user doesn't have an account, it creates one using the provided domain. +// Returns the account ID or an error if none is found or created. +func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { if accountID != "" { - return am.Store.GetAccount(ctx, accountID) - } else if userID != "" { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) + return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, userID, account) - if err != nil { - return nil, err + if !exists { + return "", status.Errorf(status.NotFound, "account %s does not exist", accountID) } - return account, nil + return accountID, nil } - return nil, status.Errorf(status.NotFound, "no valid user or account Id provided") + if userID != "" { + account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + if err != nil { + return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) + } + + if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil { + return "", err + } + + return account.Id, nil + } + + return "", status.Errorf(status.NotFound, "no valid userID or accountID provided") } func isNil(i idp.Manager) bool { @@ -1614,13 +1635,18 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err @@ -1679,6 +1705,11 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string return am.Store.SaveAccount(ctx, account) } +// GetAccount returns an account associated with this account ID. +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { + return am.Store.GetAccount(ctx, accountID) +} + // GetAccountFromPAT returns Account and User associated with a personal access token func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { if len(token) != PATLength { @@ -1727,10 +1758,24 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st return account, user, pat, nil } -// GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +// GetAccountByID returns an account associated with this account ID. +func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccount(ctx, accountID) +} + +// GetAccountIDFromToken returns an account ID associated with this token. +func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return nil, nil, fmt.Errorf("user ID is empty") + return "", "", fmt.Errorf("user ID is empty") } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1740,115 +1785,111 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) + accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) if err != nil { - return nil, nil, err - } - unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id) - alreadyUnlocked := false - defer func() { - if !alreadyUnlocked { - unlock() - } - }() - - account, err := am.Store.GetAccount(ctx, newAcc.Id) - if err != nil { - return nil, nil, err + return "", "", err } - user := account.Users[claims.UserId] - if user == nil { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + if err != nil { // this is not really possible because we got an account by user ID - return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(ctx, account, claims.UserId) + err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { - return nil, nil, err + return "", "", err } } - if account.Settings.JWTGroupsEnabled { - if account.Settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") - return account, user, nil - } - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if slice, ok := claim.([]interface{}); ok { - var groupsNames []string - for _, item := range slice { - if g, ok := item.(string); ok { - groupsNames = append(groupsNames, g) - } else { - log.WithContext(ctx).Errorf("JWT claim %q is not a string: %v", account.Settings.JWTGroupsClaimName, item) - } - } - - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) - // if groups were added or modified, save the account - if account.SetJWTGroups(claims.UserId, groupsNames) { - if account.Settings.GroupsPropagationEnabled { - if user, err := account.FindUser(claims.UserId); err == nil { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - - account.Network.IncSerial() - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } else { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { - am.updateAccountPeers(ctx, account) - } - unlock() - - alreadyUnlocked = true - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - } - } - } else { - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } - } - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q is not a string array", account.Settings.JWTGroupsClaimName) - } - } else { - log.WithContext(ctx).Debugf("JWT claim %q not found", account.Settings.JWTGroupsClaimName) - } + if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { + return "", "", err } - return account, user, nil + return accountID, user.Id, nil } -// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. +// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, +// and propagates changes to peers if group propagation is enabled. +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if settings == nil || !settings.JWTGroupsEnabled { + return nil + } + + if settings.JWTGroupsClaimName == "" { + log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + return nil + } + + // TODO: Remove GetAccount after refactoring account peer's update + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + + jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) + + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) + + // Update the account if group membership changes + if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) + + if settings.GroupsPropagationEnabled { + account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) + account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) + account.Network.IncSerial() + } + + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) + return nil + } + + // Propagate changes to peers if group propagation is enabled + if settings.GroupsPropagationEnabled { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + + for _, g := range addNewGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + + for _, g := range removeOldGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + } + + return nil +} + +// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // @@ -1865,26 +1906,34 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } + // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { - return nil, err + return "", err } - if _, ok := accountFromID.Users[claims.UserId]; !ok { - return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { - return accountFromID, nil + + domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + if err != nil { + return "", err + } + + if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { + return userAccountID, nil } } @@ -1894,48 +1943,53 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C log.WithContext(ctx).Debugf("Acquired global lock in %s for user %s", time.Since(start), claims.UserId) // We checked if the domain has a primary account already - domainAccount, err := am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) if err != nil { // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { - return nil, err + return "", err } } - account, err := am.Store.GetAccountByUser(ctx, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, account.Id) + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, userAccountID) defer unlockAccount() - account, err = am.Store.GetAccountByUser(ctx, claims.UserId) + account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { - return nil, err + return "", err } // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as // non-primary account for the domain. We don't merge accounts at this stage, because of cases when a domain // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. - primaryDomain := domainAccount == nil || account.Id == domainAccount.Id - - err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) - if err != nil { - return nil, err + primaryDomain := domainAccountID == "" || account.Id == domainAccountID + if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { + return "", err } - return account, nil + + return account.Id, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - if domainAccount != nil { - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) + var domainAccount *Account + if domainAccountID != "" { + unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { - return nil, err + return "", err } } - return am.handleNewUserAccount(ctx, domainAccount, claims) + + account, err := am.handleNewUserAccount(ctx, domainAccount, claims) + if err != nil { + return "", err + } + return account.Id, nil } else { // other error - return nil, err + return "", err } } @@ -2028,26 +2082,21 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, _, err := am.GetAccountIDFromToken(ctx, claims) + if err != nil { + return err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } // Ensures JWT group synchronization to the management is enabled before, // filtering access based on the allowed groups. - if account.Settings != nil && account.Settings.JWTGroupsEnabled { - if allowedGroups := account.Settings.JWTAllowGroups; len(allowedGroups) > 0 { - userJWTGroups := make([]string, 0) - - if claim, ok := claims.Raw[account.Settings.JWTGroupsClaimName]; ok { - if claimGroups, ok := claim.([]interface{}); ok { - for _, g := range claimGroups { - if group, ok := g.(string); ok { - userJWTGroups = append(userJWTGroups, group) - } - } - } - } + if settings != nil && settings.JWTGroupsEnabled { + if allowedGroups := settings.JWTAllowGroups; len(allowedGroups) > 0 { + userJWTGroups := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) if !userHasAllowedGroup(allowedGroups, userJWTGroups) { return fmt.Errorf("user does not belong to any of the allowed JWT groups") @@ -2076,6 +2125,60 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) + if err != nil { + return false, err + } + + err = checkIfPeerOwnerIsBlocked(peer, user) + if err != nil { + return false, err + } + + if peerLoginExpired(ctx, peer, settings) { + err = am.handleExpiredPeer(ctx, user, peer) + if err != nil { + return false, err + } + return true, nil + } + + return false, nil +} + +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) + if err != nil { + return "", fmt.Errorf("failed to get peer dns labels: %w", err) + } + + labelMap := ConvertSliceToMap(existingLabels) + newLabel, err := getPeerHostLabel(peerHostName, labelMap) + if err != nil { + return "", fmt.Errorf("failed to get new host label: %w", err) + } + + if newLabel == "" { + return "", fmt.Errorf("failed to get new host label: %w", err) + } + + return newLabel, nil +} + +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { @@ -2158,6 +2261,27 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac return acc } +// extractJWTGroups extracts the group names from a JWT token's claims. +func extractJWTGroups(ctx context.Context, claimName string, claims jwtclaims.AuthorizationClaims) []string { + userJWTGroups := make([]string, 0) + + if claim, ok := claims.Raw[claimName]; ok { + if claimGroups, ok := claim.([]interface{}); ok { + for _, g := range claimGroups { + if group, ok := g.(string); ok { + userJWTGroups = append(userJWTGroups, group) + } else { + log.WithContext(ctx).Debugf("JWT claim %q contains a non-string group (type: %T): %v", claimName, g, g) + } + } + } + } else { + log.WithContext(ctx).Debugf("JWT claim %q is not a string array", claimName) + } + + return userJWTGroups +} + // userHasAllowedGroup checks if a user belongs to any of the allowed groups. func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { for _, userGroup := range userGroups { diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go new file mode 100644 index 000000000..5f4897e6a --- /dev/null +++ b/management/server/account_request_buffer.go @@ -0,0 +1,108 @@ +package server + +import ( + "context" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +// AccountRequest holds the result channel to return the requested account. +type AccountRequest struct { + AccountID string + ResultChan chan *AccountResult +} + +// AccountResult holds the account data or an error. +type AccountResult struct { + Account *Account + Err error +} + +type AccountRequestBuffer struct { + store Store + getAccountRequests map[string][]*AccountRequest + mu sync.Mutex + getAccountRequestCh chan *AccountRequest + bufferInterval time.Duration +} + +func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer { + bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL") + bufferInterval, err := time.ParseDuration(bufferIntervalStr) + if err != nil { + if bufferIntervalStr != "" { + log.WithContext(ctx).Warnf("failed to parse account request buffer interval: %s", err) + } + bufferInterval = 100 * time.Millisecond + } + + log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval) + + ac := AccountRequestBuffer{ + store: store, + getAccountRequests: make(map[string][]*AccountRequest), + getAccountRequestCh: make(chan *AccountRequest), + bufferInterval: bufferInterval, + } + + go ac.processGetAccountRequests(ctx) + + return &ac +} +func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) { + req := &AccountRequest{ + AccountID: accountID, + ResultChan: make(chan *AccountResult, 1), + } + + log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID) + startTime := time.Now() + ac.getAccountRequestCh <- req + + result := <-req.ResultChan + log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime)) + return result.Account, result.Err +} + +func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) { + ac.mu.Lock() + requests := ac.getAccountRequests[accountID] + delete(ac.getAccountRequests, accountID) + ac.mu.Unlock() + + if len(requests) == 0 { + return + } + + startTime := time.Now() + account, err := ac.store.GetAccount(ctx, accountID) + log.WithContext(ctx).Tracef("getting account %s in batch took %s", accountID, time.Since(startTime)) + result := &AccountResult{Account: account, Err: err} + + for _, req := range requests { + req.ResultChan <- result + close(req.ResultChan) + } +} + +func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) { + for { + select { + case req := <-ac.getAccountRequestCh: + ac.mu.Lock() + ac.getAccountRequests[req.AccountID] = append(ac.getAccountRequests[req.AccountID], req) + if len(ac.getAccountRequests[req.AccountID]) == 1 { + go func(ctx context.Context, accountID string) { + time.Sleep(ac.bufferInterval) + ac.processGetAccountBatch(ctx, accountID) + }(ctx, req.AccountID) + } + ac.mu.Unlock() + case <-ctx.Done(): + return + } + } +} diff --git a/management/server/account_test.go b/management/server/account_test.go index d89ce4e4a..1bad43c13 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -462,7 +462,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { assert.Equal(t, account.Id, ev.TargetID) } -func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { +func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { type initUserParams jwtclaims.AuthorizationClaims type test struct { @@ -633,9 +633,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") + if testCase.inputUpdateAttrs { err = manager.updateAccountDomainAttributes(context.Background(), initAccount, jwtclaims.AuthorizationClaims{UserId: testCase.inputInitUserParams.UserId, Domain: testCase.inputInitUserParams.Domain, DomainCategory: testCase.inputInitUserParams.DomainCategory}, true) require.NoError(t, err, "update init user failed") @@ -645,8 +648,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) + accountID, _, err = manager.GetAccountIDFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -669,12 +676,13 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "unable to create account manager") accountID := initAccount.Id - acc, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, accountID, domain) + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userId, accountID, domain) require.NoError(t, err, "create init user failed") // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountByUserOrAccountID where the id is getting generated + // that happens inside the GetAccountIDByUserOrAccountID where the id is getting generated // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it - initAccount = acc + initAccount, err = manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get init account failed") claims := jwtclaims.AuthorizationClaims{ AccountId: accountID, // is empty as it is based on accountID right after initialization of initAccount @@ -685,8 +693,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "only ALL group should exists") }) @@ -696,8 +708,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -708,8 +724,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "get account failed") + require.Len(t, account.Groups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} @@ -874,21 +894,21 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userId, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userId, "", "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) return } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") if err != nil { - t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", account.Id) + t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountID) } - _, err = manager.GetAccountByUserOrAccountID(context.Background(), "", "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } @@ -1225,10 +1245,10 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { - t.Errorf("save policy: %v", err) - return - } + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + t.Errorf("delete default rule: %v", err) + return + } wg.Wait() } @@ -1636,9 +1656,10 @@ func TestAccount_Copy(t *testing.T) { }, Routes: map[route.ID]*route.Route{ "route1": { - ID: "route1", - PeerGroups: []string{}, - Groups: []string{"group1"}, + ID: "route1", + PeerGroups: []string{}, + Groups: []string{"group1"}, + AccessControlGroups: []string{}, }, }, NameServerGroups: map[string]*nbdns.NameServerGroup{ @@ -1705,19 +1726,22 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - assert.NotNil(t, account.Settings) - assert.Equal(t, account.Settings.PeerLoginExpirationEnabled, true) - assert.Equal(t, account.Settings.PeerLoginExpiration, 24*time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") + + assert.NotNil(t, settings) + assert.Equal(t, settings.PeerLoginExpirationEnabled, true) + assert.Equal(t, settings.PeerLoginExpiration, 24*time.Hour) } func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1729,11 +1753,16 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1770,7 +1799,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1781,7 +1810,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1798,8 +1827,12 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } - account, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + // when we mark peer as connected, the peer login expiration routine should trigger err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1814,7 +1847,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - _, err = manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + _, err = manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1826,8 +1859,12 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }) require.NoError(t, err, "unable to add peer") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to get the account") + + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") @@ -1870,10 +1907,10 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(context.Background(), userID, "", "") + accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -1881,19 +1918,22 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - account, err = manager.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "") require.NoError(t, err, "unable to get account by ID") - assert.False(t, account.Settings.PeerLoginExpirationEnabled) - assert.Equal(t, account.Settings.PeerLoginExpiration, time.Hour) + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "unable to get account settings") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + assert.False(t, settings.PeerLoginExpirationEnabled) + assert.Equal(t, settings.PeerLoginExpiration, time.Hour) + + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go index cf4dda746..096f49ea3 100644 --- a/management/server/activity/sqlite/crypt.go +++ b/management/server/activity/sqlite/crypt.go @@ -6,13 +6,14 @@ import ( "crypto/cipher" "crypto/rand" "encoding/base64" - "fmt" + "errors" ) var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} type FieldEncrypt struct { block cipher.Block + gcm cipher.AEAD } func GenerateKey() (string, error) { @@ -35,14 +36,21 @@ func NewFieldEncrypt(key string) (*FieldEncrypt, error) { if err != nil { return nil, err } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + ec := &FieldEncrypt{ block: block, + gcm: gcm, } return ec, nil } -func (ec *FieldEncrypt) Encrypt(payload string) string { +func (ec *FieldEncrypt) LegacyEncrypt(payload string) string { plainText := pkcs5Padding([]byte(payload)) cipherText := make([]byte, len(plainText)) cbc := cipher.NewCBCEncrypter(ec.block, iv) @@ -50,7 +58,22 @@ func (ec *FieldEncrypt) Encrypt(payload string) string { return base64.StdEncoding.EncodeToString(cipherText) } -func (ec *FieldEncrypt) Decrypt(data string) (string, error) { +// Encrypt encrypts plaintext using AES-GCM +func (ec *FieldEncrypt) Encrypt(payload string) (string, error) { + plaintext := []byte(payload) + nonceSize := ec.gcm.NonceSize() + + nonce := make([]byte, nonceSize, len(plaintext)+nonceSize+ec.gcm.Overhead()) + if _, err := rand.Read(nonce); err != nil { + return "", err + } + + ciphertext := ec.gcm.Seal(nonce, nonce, plaintext, nil) + + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +func (ec *FieldEncrypt) LegacyDecrypt(data string) (string, error) { cipherText, err := base64.StdEncoding.DecodeString(data) if err != nil { return "", err @@ -65,17 +88,49 @@ func (ec *FieldEncrypt) Decrypt(data string) (string, error) { return string(payload), nil } +// Decrypt decrypts ciphertext using AES-GCM +func (ec *FieldEncrypt) Decrypt(data string) (string, error) { + cipherText, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return "", err + } + + nonceSize := ec.gcm.NonceSize() + if len(cipherText) < nonceSize { + return "", errors.New("cipher text too short") + } + + nonce, cipherText := cipherText[:nonceSize], cipherText[nonceSize:] + plainText, err := ec.gcm.Open(nil, nonce, cipherText, nil) + if err != nil { + return "", err + } + + return string(plainText), nil +} + func pkcs5Padding(ciphertext []byte) []byte { padding := aes.BlockSize - len(ciphertext)%aes.BlockSize padText := bytes.Repeat([]byte{byte(padding)}, padding) return append(ciphertext, padText...) } - func pkcs5UnPadding(src []byte) ([]byte, error) { srcLen := len(src) - paddingLen := int(src[srcLen-1]) - if paddingLen >= srcLen || paddingLen > aes.BlockSize { - return nil, fmt.Errorf("padding size error") + if srcLen == 0 { + return nil, errors.New("input data is empty") } + + paddingLen := int(src[srcLen-1]) + if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen { + return nil, errors.New("invalid padding size") + } + + // Verify that all padding bytes are the same + for i := 0; i < paddingLen; i++ { + if src[srcLen-1-i] != byte(paddingLen) { + return nil, errors.New("invalid padding") + } + } + return src[:srcLen-paddingLen], nil } diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go index efa740921..aff3a08b1 100644 --- a/management/server/activity/sqlite/crypt_test.go +++ b/management/server/activity/sqlite/crypt_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "bytes" "testing" ) @@ -15,7 +16,11 @@ func TestGenerateKey(t *testing.T) { t.Fatalf("failed to init email encryption: %s", err) } - encrypted := ee.Encrypt(testData) + encrypted, err := ee.Encrypt(testData) + if err != nil { + t.Fatalf("failed to encrypt data: %s", err) + } + if encrypted == "" { t.Fatalf("invalid encrypted text") } @@ -30,6 +35,32 @@ func TestGenerateKey(t *testing.T) { } } +func TestGenerateKeyLegacy(t *testing.T) { + testData := "exampl@netbird.io" + key, err := GenerateKey() + if err != nil { + t.Fatalf("failed to generate key: %s", err) + } + ee, err := NewFieldEncrypt(key) + if err != nil { + t.Fatalf("failed to init email encryption: %s", err) + } + + encrypted := ee.LegacyEncrypt(testData) + if encrypted == "" { + t.Fatalf("invalid encrypted text") + } + + decrypted, err := ee.LegacyDecrypt(encrypted) + if err != nil { + t.Fatalf("failed to decrypt data: %s", err) + } + + if decrypted != testData { + t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted) + } +} + func TestCorruptKey(t *testing.T) { testData := "exampl@netbird.io" key, err := GenerateKey() @@ -41,7 +72,11 @@ func TestCorruptKey(t *testing.T) { t.Fatalf("failed to init email encryption: %s", err) } - encrypted := ee.Encrypt(testData) + encrypted, err := ee.Encrypt(testData) + if err != nil { + t.Fatalf("failed to encrypt data: %s", err) + } + if encrypted == "" { t.Fatalf("invalid encrypted text") } @@ -61,3 +96,215 @@ func TestCorruptKey(t *testing.T) { t.Fatalf("incorrect decryption, the result is: %s", res) } } + +func TestEncryptDecrypt(t *testing.T) { + // Generate a key for encryption/decryption + key, err := GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + // Initialize the FieldEncrypt with the generated key + ec, err := NewFieldEncrypt(key) + if err != nil { + t.Fatalf("Failed to create FieldEncrypt: %v", err) + } + + // Test cases + testCases := []struct { + name string + input string + }{ + { + name: "Empty String", + input: "", + }, + { + name: "Short String", + input: "Hello", + }, + { + name: "String with Spaces", + input: "Hello, World!", + }, + { + name: "Long String", + input: "The quick brown fox jumps over the lazy dog.", + }, + { + name: "Unicode Characters", + input: "こんにちは世界", + }, + { + name: "Special Characters", + input: "!@#$%^&*()_+-=[]{}|;':\",./<>?", + }, + { + name: "Numeric String", + input: "1234567890", + }, + { + name: "Repeated Characters", + input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + }, + { + name: "Multi-block String", + input: "This is a longer string that will span multiple blocks in the encryption algorithm.", + }, + { + name: "Non-ASCII and ASCII Mix", + input: "Hello 世界 123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name+" - Legacy", func(t *testing.T) { + // Legacy Encryption + encryptedLegacy := ec.LegacyEncrypt(tc.input) + if encryptedLegacy == "" { + t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input) + } + + // Legacy Decryption + decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy) + if err != nil { + t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err) + } + + // Verify that the decrypted value matches the original input + if decryptedLegacy != tc.input { + t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input) + } + }) + + t.Run(tc.name+" - New", func(t *testing.T) { + // New Encryption + encryptedNew, err := ec.Encrypt(tc.input) + if err != nil { + t.Errorf("Encrypt failed for input '%s': %v", tc.input, err) + } + if encryptedNew == "" { + t.Errorf("Encrypt returned empty string for input '%s'", tc.input) + } + + // New Decryption + decryptedNew, err := ec.Decrypt(encryptedNew) + if err != nil { + t.Errorf("Decrypt failed for input '%s': %v", tc.input, err) + } + + // Verify that the decrypted value matches the original input + if decryptedNew != tc.input { + t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input) + } + }) + } +} + +func TestPKCS5UnPadding(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + expectError bool + }{ + { + name: "Valid Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), + expected: []byte("Hello, World!"), + }, + { + name: "Empty Input", + input: []byte{}, + expectError: true, + }, + { + name: "Padding Length Zero", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Block Size", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Input Length", + input: []byte{5, 5, 5}, + expectError: true, + }, + { + name: "Invalid Padding Bytes", + input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), + expectError: true, + }, + { + name: "Valid Single Byte Padding", + input: append([]byte("Hello, World!"), byte(1)), + expected: []byte("Hello, World!"), + }, + { + name: "Invalid Mixed Padding Bytes", + input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), + expectError: true, + }, + { + name: "Valid Full Block Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("Hello, World!"), + }, + { + name: "Non-Padding Byte at End", + input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), + expectError: true, + }, + { + name: "Valid Padding with Different Text Length", + input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), + expected: []byte("Test"), + }, + { + name: "Padding Length Equal to Input Length", + input: bytes.Repeat([]byte{8}, 8), + expected: []byte{}, + }, + { + name: "Invalid Padding Length Zero (Again)", + input: append([]byte("Test"), byte(0)), + expectError: true, + }, + { + name: "Padding Length Greater Than Input", + input: []byte{10}, + expectError: true, + }, + { + name: "Input Length Not Multiple of Block Size", + input: append([]byte("Invalid Length"), byte(1)), + expected: []byte("Invalid Length"), + }, + { + name: "Valid Padding with Non-ASCII Characters", + input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), + expected: []byte("こんにちは"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs5UnPadding(tt.input) + if tt.expectError { + if err == nil { + t.Errorf("Expected error but got nil") + } + } else { + if err != nil { + t.Errorf("Did not expect error but got: %v", err) + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("Expected output %v, got %v", tt.expected, result) + } + } + }) + } +} diff --git a/management/server/activity/sqlite/migration.go b/management/server/activity/sqlite/migration.go new file mode 100644 index 000000000..28c5b3020 --- /dev/null +++ b/management/server/activity/sqlite/migration.go @@ -0,0 +1,157 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + + log "github.com/sirupsen/logrus" +) + +func migrate(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error { + if _, err := db.Exec(createTableQuery); err != nil { + return err + } + + if _, err := db.Exec(creatTableDeletedUsersQuery); err != nil { + return err + } + + if err := updateDeletedUsersTable(ctx, db); err != nil { + return fmt.Errorf("failed to update deleted_users table: %v", err) + } + + return migrateLegacyEncryptedUsersToGCM(ctx, crypt, db) +} + +// updateDeletedUsersTable checks and updates the deleted_users table schema to ensure required columns exist. +func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { + exists, err := checkColumnExists(db, "deleted_users", "name") + if err != nil { + return err + } + + if !exists { + log.WithContext(ctx).Debug("Adding name column to the deleted_users table") + + _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) + if err != nil { + return err + } + + log.WithContext(ctx).Debug("Successfully added name column to the deleted_users table") + } + + exists, err = checkColumnExists(db, "deleted_users", "enc_algo") + if err != nil { + return err + } + + if !exists { + log.WithContext(ctx).Debug("Adding enc_algo column to the deleted_users table") + + _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN enc_algo TEXT;`) + if err != nil { + return err + } + + log.WithContext(ctx).Debug("Successfully added enc_algo column to the deleted_users table") + } + + return nil +} + +// migrateLegacyEncryptedUsersToGCM migrates previously encrypted data using, +// legacy CBC encryption with a static IV to the new GCM encryption method. +func migrateLegacyEncryptedUsersToGCM(ctx context.Context, crypt *FieldEncrypt, db *sql.DB) error { + log.WithContext(ctx).Debug("Migrating CBC encrypted deleted users to GCM") + + tx, err := db.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %v", err) + } + defer func() { + _ = tx.Rollback() + }() + + rows, err := tx.Query(fmt.Sprintf(`SELECT id, email, name FROM deleted_users where enc_algo IS NULL OR enc_algo != '%s'`, gcmEncAlgo)) + if err != nil { + return fmt.Errorf("failed to execute select query: %v", err) + } + defer rows.Close() + + updateStmt, err := tx.Prepare(`UPDATE deleted_users SET email = ?, name = ?, enc_algo = ? WHERE id = ?`) + if err != nil { + return fmt.Errorf("failed to prepare update statement: %v", err) + } + defer updateStmt.Close() + + if err = processUserRows(ctx, crypt, rows, updateStmt); err != nil { + return err + } + + if err = tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %v", err) + } + + log.WithContext(ctx).Debug("Successfully migrated CBC encrypted deleted users to GCM") + return nil +} + +// processUserRows processes database rows of user data, decrypts legacy encryption fields, and re-encrypts them using GCM. +func processUserRows(ctx context.Context, crypt *FieldEncrypt, rows *sql.Rows, updateStmt *sql.Stmt) error { + for rows.Next() { + var ( + id, decryptedEmail, decryptedName string + email, name *string + ) + + err := rows.Scan(&id, &email, &name) + if err != nil { + return err + } + + if email != nil { + decryptedEmail, err = crypt.LegacyDecrypt(*email) + if err != nil { + log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v", + id, + fmt.Errorf("failed to decrypt email: %w", err), + ) + continue + } + } + + if name != nil { + decryptedName, err = crypt.LegacyDecrypt(*name) + if err != nil { + log.WithContext(ctx).Warnf("skipping migrating deleted user %s: %v", + id, + fmt.Errorf("failed to decrypt name: %w", err), + ) + continue + } + } + + encryptedEmail, err := crypt.Encrypt(decryptedEmail) + if err != nil { + return fmt.Errorf("failed to encrypt email: %w", err) + } + + encryptedName, err := crypt.Encrypt(decryptedName) + if err != nil { + return fmt.Errorf("failed to encrypt name: %w", err) + } + + _, err = updateStmt.Exec(encryptedEmail, encryptedName, gcmEncAlgo, id) + if err != nil { + return err + } + } + + if err := rows.Err(); err != nil { + return err + } + + return nil +} diff --git a/management/server/activity/sqlite/migration_test.go b/management/server/activity/sqlite/migration_test.go new file mode 100644 index 000000000..a03774fa8 --- /dev/null +++ b/management/server/activity/sqlite/migration_test.go @@ -0,0 +1,84 @@ +package sqlite + +import ( + "context" + "database/sql" + "path/filepath" + "testing" + "time" + + _ "github.com/mattn/go-sqlite3" + "github.com/netbirdio/netbird/management/server/activity" + + "github.com/stretchr/testify/require" +) + +func setupDatabase(t *testing.T) *sql.DB { + t.Helper() + + dbFile := filepath.Join(t.TempDir(), eventSinkDB) + db, err := sql.Open("sqlite3", dbFile) + require.NoError(t, err, "Failed to open database") + + t.Cleanup(func() { + _ = db.Close() + }) + + _, err = db.Exec(createTableQuery) + require.NoError(t, err, "Failed to create events table") + + _, err = db.Exec(`CREATE TABLE deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);`) + require.NoError(t, err, "Failed to create deleted_users table") + + return db +} + +func TestMigrate(t *testing.T) { + db := setupDatabase(t) + + key, err := GenerateKey() + require.NoError(t, err, "Failed to generate key") + + crypt, err := NewFieldEncrypt(key) + require.NoError(t, err, "Failed to initialize FieldEncrypt") + + legacyEmail := crypt.LegacyEncrypt("testaccount@test.com") + legacyName := crypt.LegacyEncrypt("Test Account") + + _, err = db.Exec(`INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) VALUES(?, ?, ?, ?, ?, ?)`, + activity.UserDeleted, time.Now(), "initiatorID", "targetID", "accountID", "") + require.NoError(t, err, "Failed to insert event") + + _, err = db.Exec(`INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)`, "targetID", legacyEmail, legacyName) + require.NoError(t, err, "Failed to insert legacy encrypted data") + + colExists, err := checkColumnExists(db, "deleted_users", "enc_algo") + require.NoError(t, err, "Failed to check if enc_algo column exists") + require.False(t, colExists, "enc_algo column should not exist before migration") + + err = migrate(context.Background(), crypt, db) + require.NoError(t, err, "Migration failed") + + colExists, err = checkColumnExists(db, "deleted_users", "enc_algo") + require.NoError(t, err, "Failed to check if enc_algo column exists after migration") + require.True(t, colExists, "enc_algo column should exist after migration") + + var encAlgo string + err = db.QueryRow(`SELECT enc_algo FROM deleted_users LIMIT 1`, "").Scan(&encAlgo) + require.NoError(t, err, "Failed to select updated data") + require.Equal(t, gcmEncAlgo, encAlgo, "enc_algo should be set to 'GCM' after migration") + + store, err := createStore(crypt, db) + require.NoError(t, err, "Failed to create store") + + events, err := store.Get(context.Background(), "accountID", 0, 1, false) + require.NoError(t, err, "Failed to get events") + + require.Len(t, events, 1, "Should have one event") + require.Equal(t, activity.UserDeleted, events[0].Activity, "activity should match") + require.Equal(t, "initiatorID", events[0].InitiatorID, "initiator id should match") + require.Equal(t, "targetID", events[0].TargetID, "target id should match") + require.Equal(t, "accountID", events[0].AccountID, "account id should match") + require.Equal(t, "testaccount@test.com", events[0].Meta["email"], "email should match") + require.Equal(t, "Test Account", events[0].Meta["username"], "username should match") +} diff --git a/management/server/activity/sqlite/sqlite.go b/management/server/activity/sqlite/sqlite.go index fadf1eb07..823e0b4ac 100644 --- a/management/server/activity/sqlite/sqlite.go +++ b/management/server/activity/sqlite/sqlite.go @@ -26,7 +26,7 @@ const ( "meta TEXT," + " target_id TEXT);" - creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT);` + creatTableDeletedUsersQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL, name TEXT, enc_algo TEXT NOT NULL);` selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.name as "initiator_name", i.email as "initiator_email", target_id, t.name as "target_name", t.email as "target_email", account_id, meta FROM events @@ -69,10 +69,12 @@ const ( and some selfhosted deployments might have duplicates already so we need to clean the table first. */ - insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name) VALUES(?, ?, ?)` + insertDeleteUserQuery = `INSERT INTO deleted_users(id, email, name, enc_algo) VALUES(?, ?, ?, ?)` fallbackName = "unknown" fallbackEmail = "unknown@unknown.com" + + gcmEncAlgo = "GCM" ) // Store is the implementation of the activity.Store interface backed by SQLite @@ -100,58 +102,12 @@ func NewSQLiteStore(ctx context.Context, dataDir string, encryptionKey string) ( return nil, err } - _, err = db.Exec(createTableQuery) - if err != nil { + if err = migrate(ctx, crypt, db); err != nil { _ = db.Close() - return nil, err + return nil, fmt.Errorf("events database migration: %w", err) } - _, err = db.Exec(creatTableDeletedUsersQuery) - if err != nil { - _ = db.Close() - return nil, err - } - - err = updateDeletedUsersTable(ctx, db) - if err != nil { - _ = db.Close() - return nil, err - } - - insertStmt, err := db.Prepare(insertQuery) - if err != nil { - _ = db.Close() - return nil, err - } - - selectDescStmt, err := db.Prepare(selectDescQuery) - if err != nil { - _ = db.Close() - return nil, err - } - - selectAscStmt, err := db.Prepare(selectAscQuery) - if err != nil { - _ = db.Close() - return nil, err - } - - deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) - if err != nil { - _ = db.Close() - return nil, err - } - - s := &Store{ - db: db, - fieldEncrypt: crypt, - insertStatement: insertStmt, - selectDescStatement: selectDescStmt, - selectAscStatement: selectAscStmt, - deleteUserStmt: deleteUserStmt, - } - - return s, nil + return createStore(crypt, db) } func (store *Store) processResult(ctx context.Context, result *sql.Rows) ([]*activity.Event, error) { @@ -302,9 +258,16 @@ func (store *Store) saveDeletedUserEmailAndNameInEncrypted(event *activity.Event return event.Meta, nil } - encryptedEmail := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) - encryptedName := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) - _, err := store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName) + encryptedEmail, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", email)) + if err != nil { + return nil, err + } + encryptedName, err := store.fieldEncrypt.Encrypt(fmt.Sprintf("%s", name)) + if err != nil { + return nil, err + } + + _, err = store.deleteUserStmt.Exec(event.TargetID, encryptedEmail, encryptedName, gcmEncAlgo) if err != nil { return nil, err } @@ -325,43 +288,70 @@ func (store *Store) Close(_ context.Context) error { return nil } -func updateDeletedUsersTable(ctx context.Context, db *sql.DB) error { - log.WithContext(ctx).Debugf("check deleted_users table version") - rows, err := db.Query(`PRAGMA table_info(deleted_users);`) +// createStore initializes and returns a new Store instance with prepared SQL statements. +func createStore(crypt *FieldEncrypt, db *sql.DB) (*Store, error) { + insertStmt, err := db.Prepare(insertQuery) if err != nil { - return err + _ = db.Close() + return nil, err + } + + selectDescStmt, err := db.Prepare(selectDescQuery) + if err != nil { + _ = db.Close() + return nil, err + } + + selectAscStmt, err := db.Prepare(selectAscQuery) + if err != nil { + _ = db.Close() + return nil, err + } + + deleteUserStmt, err := db.Prepare(insertDeleteUserQuery) + if err != nil { + _ = db.Close() + return nil, err + } + + return &Store{ + db: db, + fieldEncrypt: crypt, + insertStatement: insertStmt, + selectDescStatement: selectDescStmt, + selectAscStatement: selectAscStmt, + deleteUserStmt: deleteUserStmt, + }, nil +} + +// checkColumnExists checks if a column exists in a specified table +func checkColumnExists(db *sql.DB, tableName, columnName string) (bool, error) { + query := fmt.Sprintf("PRAGMA table_info(%s);", tableName) + rows, err := db.Query(query) + if err != nil { + return false, fmt.Errorf("failed to query table info: %w", err) } defer rows.Close() - found := false + for rows.Next() { - var ( - cid int - name string - dataType string - notNull int - dfltVal sql.NullString - pk int - ) - err := rows.Scan(&cid, &name, &dataType, ¬Null, &dfltVal, &pk) + var cid int + var name, ctype string + var notnull, pk int + var dfltValue sql.NullString + + err = rows.Scan(&cid, &name, &ctype, ¬null, &dfltValue, &pk) if err != nil { - return err + return false, fmt.Errorf("failed to scan row: %w", err) } - if name == "name" { - found = true - break + + if name == columnName { + return true, nil } } - err = rows.Err() - if err != nil { - return err + if err = rows.Err(); err != nil { + return false, err } - if found { - return nil - } - - log.WithContext(ctx).Debugf("update delted_users table") - _, err = db.Exec(`ALTER TABLE deleted_users ADD COLUMN name TEXT;`) - return err + return false, nil } diff --git a/management/server/config.go b/management/server/config.go index 4efe4fe74..2f7e49766 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -34,6 +34,7 @@ const ( type Config struct { Stuns []*Host TURNConfig *TURNConfig + Relay *Relay Signal *Host Datadir string @@ -75,6 +76,12 @@ type TURNConfig struct { Turns []*Host } +type Relay struct { + Addresses []string + CredentialsTTL util.Duration + Secret string +} + // HttpServerConfig is a config of the HTTP Management service server type HttpServerConfig struct { LetsEncryptDomain string diff --git a/management/server/dns.go b/management/server/dns.go index fa62a0d4a..256b8b125 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -80,24 +80,16 @@ func (d DNSSettings) Copy() DNSSettings { // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") } - dnsSettings := account.DNSSettings.Copy() - return &dnsSettings, nil + + return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 36c88f1d1..1390352a5 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,6 +7,7 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) type MockStore struct { @@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou return s.account, nil } - return nil, fmt.Errorf("account not found") + return nil, status.NewPeerNotFoundError(peerId) } type MocAccountManager struct { diff --git a/management/server/file_store.go b/management/server/file_store.go index 6e3536bcd..994a4b1ee 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,20 +2,23 @@ package server import ( "context" + "errors" + "net" "os" "path/filepath" "strings" "sync" "time" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" - + "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/util" ) @@ -46,6 +49,158 @@ type FileStore struct { metrics telemetry.AppMetrics `json:"-"` } +func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error { + return f(s) +} + +func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] + if !ok { + return status.NewSetupKeyNotFoundError() + } + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + account.SetupKeys[setupKeyID].UsedTimes++ + + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + allGroup, err := account.GetGroupAll() + if err != nil || allGroup == nil { + return errors.New("all group not found") + } + + allGroup.Peers = append(allGroup.Peers, peerID) + + return nil +} + +func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountId) + if err != nil { + return err + } + + account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) + + return nil +} + +func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, ok := s.Accounts[peer.AccountID] + if !ok { + return status.NewAccountNotFoundError(peer.AccountID) + } + + account.Peers[peer.ID] = peer + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, ok := s.Accounts[accountId] + if !ok { + return status.NewAccountNotFoundError(accountId) + } + + account.Network.Serial++ + + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] + if !ok { + return nil, status.NewSetupKeyNotFoundError() + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + setupKey, ok := account.SetupKeys[key] + if !ok { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return setupKey, nil +} + +func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + var takenIps []net.IP + for _, existingPeer := range account.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps, nil +} + +func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + existingLabels := []string{} + for _, peer := range account.Peers { + if peer.DNSLabel != "" { + existingLabels = append(existingLabels, peer.DNSLabel) + } + } + return existingLabels, nil +} + +func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Network, nil +} + type StoredAccount struct{} // NewFileStore restores a store from the file located in the datadir @@ -422,7 +577,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + return nil, status.NewSetupKeyNotFoundError() } account, err := s.getAccount(accountID) @@ -469,6 +624,44 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, return account.Users[userID].Copy(), nil } +func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { + accountID, ok := s.UserID2AccountID[userID] + if !ok { + return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + user := account.Users[userID].Copy() + pat := make([]PersonalAccessToken, 0, len(user.PATs)) + for _, token := range user.PATs { + if token != nil { + pat = append(pat, *token) + } + } + user.PATsG = pat + + return user, nil +} + +func (s *FileStore) GetAccountGroups(_ context.Context, accountID string) ([]*nbgroup.Group, error) { + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups)) + + for _, group := range account.Groups { + groupsSlice = append(groupsSlice, group) + } + + return groupsSlice, nil +} + // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() @@ -484,7 +677,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { func (s *FileStore) getAccount(accountID string) (*Account, error) { account, ok := s.Accounts[accountID] if !ok { - return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewAccountNotFoundError(accountID) } return account, nil @@ -610,13 +803,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) ( accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !ok { - return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + return "", status.NewSetupKeyNotFoundError() } return accountID, nil } -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { +func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { s.mux.Lock() defer s.mux.Unlock() @@ -639,7 +832,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp return nil, status.NewPeerNotFoundError(peerKey) } -func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { +func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { s.mux.Lock() defer s.mux.Unlock() @@ -729,7 +922,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer. } // SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. -func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { +func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { s.mux.Lock() defer s.mux.Unlock() @@ -748,7 +941,7 @@ func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.T return nil } -func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { +func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") } @@ -767,10 +960,85 @@ func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } -func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { +func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { return status.Errorf(status.Internal, "SaveUsers is not implemented") } -func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { +func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } + +func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ LockingStrength, _ string) (string, error) { + return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") +} + +func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ LockingStrength, accountID string) (string, string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return "", "", err + } + + return account.Domain, account.DomainCategory, nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *FileStore) AccountExists(_ context.Context, _ LockingStrength, id string) (bool, error) { + _, exists := s.Accounts[id] + return exists, nil +} + +func (s *FileStore) GetAccountDNSSettings(_ context.Context, _ LockingStrength, _ string) (*DNSSettings, error) { + return nil, status.Errorf(status.Internal, "GetAccountDNSSettings is not implemented") +} + +func (s *FileStore) GetGroupByID(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") +} + +func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") +} + +func (s *FileStore) GetAccountPolicies(_ context.Context, _ LockingStrength, _ string) ([]*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") +} + +func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") + +} + +func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ LockingStrength, _ string) ([]*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") +} + +func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") +} + +func (s *FileStore) GetAccountRoutes(_ context.Context, _ LockingStrength, _ string) ([]*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetAccountRoutes is not implemented") +} + +func (s *FileStore) GetRouteByID(_ context.Context, _ LockingStrength, _ string, _ string) (*route.Route, error) { + return nil, status.Errorf(status.Internal, "GetRouteByID is not implemented") +} + +func (s *FileStore) GetAccountSetupKeys(_ context.Context, _ LockingStrength, _ string) ([]*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetAccountSetupKeys is not implemented") +} + +func (s *FileStore) GetSetupKeyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*SetupKey, error) { + return nil, status.Errorf(status.Internal, "GetSetupKeyByID is not implemented") +} + +func (s *FileStore) GetAccountNameServerGroups(_ context.Context, _ LockingStrength, _ string) ([]*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetAccountNameServerGroups is not implemented") +} + +func (s *FileStore) GetNameServerGroupByID(_ context.Context, _ LockingStrength, _ string, _ string) (*dns.NameServerGroup, error) { + return nil, status.Errorf(status.Internal, "GetNameServerGroupByID is not implemented") +} diff --git a/management/server/geolocation/database.go b/management/server/geolocation/database.go index c9b2eafff..21ae93b9d 100644 --- a/management/server/geolocation/database.go +++ b/management/server/geolocation/database.go @@ -1,10 +1,9 @@ package geolocation import ( + "context" "encoding/csv" - "fmt" "io" - "net/url" "os" "path" "strconv" @@ -20,26 +19,27 @@ const ( geoLiteCityZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip" geoLiteCitySha256TarURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256" geoLiteCitySha256ZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256" + geoLiteCityMMDB = "GeoLite2-City.mmdb" + geoLiteCityCSV = "GeoLite2-City-Locations-en.csv" ) // loadGeolocationDatabases loads the MaxMind databases. -func loadGeolocationDatabases(dataDir string) error { - files := []string{MMDBFileName, GeoSqliteDBFile} - for _, file := range files { +func loadGeolocationDatabases(ctx context.Context, dataDir string, mmdbFile string, geonamesdbFile string) error { + for _, file := range []string{mmdbFile, geonamesdbFile} { exists, _ := fileExists(path.Join(dataDir, file)) if exists { continue } - log.Infof("geo location file %s not found , file will be downloaded", file) + log.WithContext(ctx).Infof("Geolocation database file %s not found, file will be downloaded", file) switch file { - case MMDBFileName: + case mmdbFile: extractFunc := func(src string, dst string) error { if err := decompressTarGzFile(src, dst); err != nil { return err } - return copyFile(path.Join(dst, MMDBFileName), path.Join(dataDir, MMDBFileName)) + return copyFile(path.Join(dst, geoLiteCityMMDB), path.Join(dataDir, mmdbFile)) } if err := loadDatabase( geoLiteCitySha256TarURL, @@ -49,13 +49,13 @@ func loadGeolocationDatabases(dataDir string) error { return err } - case GeoSqliteDBFile: + case geonamesdbFile: extractFunc := func(src string, dst string) error { if err := decompressZipFile(src, dst); err != nil { return err } - extractedCsvFile := path.Join(dst, "GeoLite2-City-Locations-en.csv") - return importCsvToSqlite(dataDir, extractedCsvFile) + extractedCsvFile := path.Join(dst, geoLiteCityCSV) + return importCsvToSqlite(dataDir, extractedCsvFile, geonamesdbFile) } if err := loadDatabase( @@ -79,7 +79,12 @@ func loadDatabase(checksumURL string, fileURL string, extractFunc func(src strin } defer os.RemoveAll(temp) - checksumFile := path.Join(temp, getDatabaseFileName(checksumURL)) + checksumFilename, err := getFilenameFromURL(checksumURL) + if err != nil { + return err + } + checksumFile := path.Join(temp, checksumFilename) + err = downloadFile(checksumURL, checksumFile) if err != nil { return err @@ -90,7 +95,12 @@ func loadDatabase(checksumURL string, fileURL string, extractFunc func(src strin return err } - dbFile := path.Join(temp, getDatabaseFileName(fileURL)) + dbFilename, err := getFilenameFromURL(fileURL) + if err != nil { + return err + } + dbFile := path.Join(temp, dbFilename) + err = downloadFile(fileURL, dbFile) if err != nil { return err @@ -104,13 +114,13 @@ func loadDatabase(checksumURL string, fileURL string, extractFunc func(src strin } // importCsvToSqlite imports a CSV file into a SQLite database. -func importCsvToSqlite(dataDir string, csvFile string) error { +func importCsvToSqlite(dataDir string, csvFile string, geonamesdbFile string) error { geonames, err := loadGeonamesCsv(csvFile) if err != nil { return err } - db, err := gorm.Open(sqlite.Open(path.Join(dataDir, GeoSqliteDBFile)), &gorm.Config{ + db, err := gorm.Open(sqlite.Open(path.Join(dataDir, geonamesdbFile)), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), CreateBatchSize: 1000, PrepareStmt: true, @@ -178,18 +188,6 @@ func loadGeonamesCsv(filepath string) ([]GeoNames, error) { return geoNames, nil } -// getDatabaseFileName extracts the file name from a given URL string. -func getDatabaseFileName(urlStr string) string { - u, err := url.Parse(urlStr) - if err != nil { - panic(err) - } - - ext := u.Query().Get("suffix") - fileName := fmt.Sprintf("%s.%s", path.Base(u.Path), ext) - return fileName -} - // copyFile performs a file copy operation from the source file to the destination. func copyFile(src string, dst string) error { srcFile, err := os.Open(src) diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 794f9d0be..553a31581 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -1,29 +1,25 @@ package geolocation import ( - "bytes" "context" "fmt" "net" "os" "path" + "path/filepath" + "strings" "sync" - "time" "github.com/oschwald/maxminddb-golang" log "github.com/sirupsen/logrus" ) -const MMDBFileName = "GeoLite2-City.mmdb" - type Geolocation struct { - mmdbPath string - mux sync.RWMutex - sha256sum []byte - db *maxminddb.Reader - locationDB *SqliteStore - stopCh chan struct{} - reloadCheckInterval time.Duration + mmdbPath string + mux sync.RWMutex + db *maxminddb.Reader + locationDB *SqliteStore + stopCh chan struct{} } type Record struct { @@ -53,45 +49,56 @@ type Country struct { CountryName string } -func NewGeolocation(ctx context.Context, dataDir string) (*Geolocation, error) { - if err := loadGeolocationDatabases(dataDir); err != nil { +const ( + mmdbPattern = "GeoLite2-City_*.mmdb" + geonamesdbPattern = "geonames_*.db" +) + +func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { + mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) + mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) + if err != nil { + return nil, fmt.Errorf("failed to get database filename: %v", err) + } + + geonamesDbGlobPattern := filepath.Join(dataDir, geonamesdbPattern) + geonamesDbFile, err := getDatabaseFilename(ctx, geoLiteCityZipURL, geonamesDbGlobPattern, autoUpdate) + if err != nil { + return nil, fmt.Errorf("failed to get database filename: %v", err) + } + + if err := loadGeolocationDatabases(ctx, dataDir, mmdbFile, geonamesDbFile); err != nil { return nil, fmt.Errorf("failed to load MaxMind databases: %v", err) } - mmdbPath := path.Join(dataDir, MMDBFileName) + if err := cleanupMaxMindDatabases(ctx, dataDir, mmdbFile, geonamesDbFile); err != nil { + return nil, fmt.Errorf("failed to remove old MaxMind databases: %v", err) + } + + mmdbPath := path.Join(dataDir, mmdbFile) db, err := openDB(mmdbPath) if err != nil { return nil, err } - sha256sum, err := calculateFileSHA256(mmdbPath) - if err != nil { - return nil, err - } - - locationDB, err := NewSqliteStore(ctx, dataDir) + locationDB, err := NewSqliteStore(ctx, dataDir, geonamesDbFile) if err != nil { return nil, err } geo := &Geolocation{ - mmdbPath: mmdbPath, - mux: sync.RWMutex{}, - sha256sum: sha256sum, - db: db, - locationDB: locationDB, - reloadCheckInterval: 300 * time.Second, // TODO: make configurable - stopCh: make(chan struct{}), + mmdbPath: mmdbPath, + mux: sync.RWMutex{}, + db: db, + locationDB: locationDB, + stopCh: make(chan struct{}), } - go geo.reloader(ctx) - return geo, nil } func openDB(mmdbPath string) (*maxminddb.Reader, error) { _, err := os.Stat(mmdbPath) - if os.IsNotExist(err) { return nil, fmt.Errorf("%v does not exist", mmdbPath) } else if err != nil { @@ -166,70 +173,6 @@ func (gl *Geolocation) Stop() error { return nil } -func (gl *Geolocation) reloader(ctx context.Context) { - for { - select { - case <-gl.stopCh: - return - case <-time.After(gl.reloadCheckInterval): - if err := gl.locationDB.reload(ctx); err != nil { - log.WithContext(ctx).Errorf("geonames db reload failed: %s", err) - } - - newSha256sum1, err := calculateFileSHA256(gl.mmdbPath) - if err != nil { - log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) - continue - } - if !bytes.Equal(gl.sha256sum, newSha256sum1) { - // we check sum twice just to avoid possible case when we reload during update of the file - // considering the frequency of file update (few times a week) checking sum twice should be enough - time.Sleep(50 * time.Millisecond) - newSha256sum2, err := calculateFileSHA256(gl.mmdbPath) - if err != nil { - log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) - continue - } - if !bytes.Equal(newSha256sum1, newSha256sum2) { - log.WithContext(ctx).Errorf("sha256 sum changed during reloading of '%s'", gl.mmdbPath) - continue - } - err = gl.reload(ctx, newSha256sum2) - if err != nil { - log.WithContext(ctx).Errorf("mmdb reload failed: %s", err) - } - } else { - log.WithContext(ctx).Tracef("No changes in '%s', no need to reload. Next check is in %.0f seconds.", - gl.mmdbPath, gl.reloadCheckInterval.Seconds()) - } - } - } -} - -func (gl *Geolocation) reload(ctx context.Context, newSha256sum []byte) error { - gl.mux.Lock() - defer gl.mux.Unlock() - - log.WithContext(ctx).Infof("Reloading '%s'", gl.mmdbPath) - - err := gl.db.Close() - if err != nil { - return err - } - - db, err := openDB(gl.mmdbPath) - if err != nil { - return err - } - - gl.db = db - gl.sha256sum = newSha256sum - - log.WithContext(ctx).Infof("Successfully reloaded '%s'", gl.mmdbPath) - - return nil -} - func fileExists(filePath string) (bool, error) { _, err := os.Stat(filePath) if err == nil { @@ -240,3 +183,79 @@ func fileExists(filePath string) (bool, error) { } return false, err } + +func getExistingDatabases(pattern string) []string { + files, _ := filepath.Glob(pattern) + return files +} + +func getDatabaseFilename(ctx context.Context, databaseURL string, filenamePattern string, autoUpdate bool) (string, error) { + var ( + filename string + err error + ) + + if autoUpdate { + filename, err = getFilenameFromURL(databaseURL) + if err != nil { + log.WithContext(ctx).Debugf("Failed to update database from url: %s", databaseURL) + return "", err + } + } else { + files := getExistingDatabases(filenamePattern) + if len(files) < 1 { + filename, err = getFilenameFromURL(databaseURL) + if err != nil { + log.WithContext(ctx).Debugf("Failed to get database from url: %s", databaseURL) + return "", err + } + } else { + filename = filepath.Base(files[len(files)-1]) + log.WithContext(ctx).Debugf("Using existing database, %s", filename) + return filename, nil + } + } + + // strip suffixes that may be nested, such as .tar.gz + basename := strings.SplitN(filename, ".", 2)[0] + // get date version from basename + date := strings.SplitN(basename, "_", 2)[1] + // format db as "GeoLite2-Cities-{maxmind|geonames}_{DATE}.{mmdb|db}" + databaseFilename := filepath.Base(strings.Replace(filenamePattern, "*", date, 1)) + + return databaseFilename, nil +} + +func cleanupOldDatabases(ctx context.Context, pattern string, currentFile string) error { + files := getExistingDatabases(pattern) + + for _, db := range files { + if filepath.Base(db) == currentFile { + continue + } + log.WithContext(ctx).Debugf("Removing old database: %s", db) + err := os.Remove(db) + if err != nil { + return err + } + } + return nil +} + +func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile string, geonamesdbFile string) error { + for _, file := range []string{mmdbFile, geonamesdbFile} { + switch file { + case mmdbFile: + pattern := filepath.Join(dataDir, mmdbPattern) + if err := cleanupOldDatabases(ctx, pattern, file); err != nil { + return err + } + case geonamesdbFile: + pattern := filepath.Join(dataDir, geonamesdbPattern) + if err := cleanupOldDatabases(ctx, pattern, file); err != nil { + return err + } + } + } + return nil +} diff --git a/management/server/geolocation/geolocation_test.go b/management/server/geolocation/geolocation_test.go index 6fd46fcfe..9bdefd268 100644 --- a/management/server/geolocation/geolocation_test.go +++ b/management/server/geolocation/geolocation_test.go @@ -2,8 +2,8 @@ package geolocation import ( "net" - "os" "path" + "path/filepath" "sync" "testing" @@ -13,21 +13,15 @@ import ( ) // from https://github.com/maxmind/MaxMind-DB/blob/main/test-data/GeoLite2-City-Test.mmdb -var mmdbPath = "../testdata/GeoLite2-City-Test.mmdb" +var mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" func TestGeoLite_Lookup(t *testing.T) { tempDir := t.TempDir() - filename := path.Join(tempDir, MMDBFileName) + filename := path.Join(tempDir, filepath.Base(mmdbPath)) err := util.CopyFileContents(mmdbPath, filename) assert.NoError(t, err) - defer func() { - err := os.Remove(filename) - if err != nil { - t.Errorf("os.Remove: %s", err) - } - }() - db, err := openDB(mmdbPath) + db, err := openDB(filename) assert.NoError(t, err) geo := &Geolocation{ diff --git a/management/server/geolocation/store.go b/management/server/geolocation/store.go index 67d420cfd..1f94bf47e 100644 --- a/management/server/geolocation/store.go +++ b/management/server/geolocation/store.go @@ -1,7 +1,6 @@ package geolocation import ( - "bytes" "context" "fmt" "path/filepath" @@ -17,10 +16,6 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -const ( - GeoSqliteDBFile = "geonames.db" -) - type GeoNames struct { GeoNameID int `gorm:"column:geoname_id"` LocaleCode string `gorm:"column:locale_code"` @@ -44,31 +39,24 @@ func (*GeoNames) TableName() string { // SqliteStore represents a location storage backed by a Sqlite DB. type SqliteStore struct { - db *gorm.DB - filePath string - mux sync.RWMutex - closed bool - sha256sum []byte + db *gorm.DB + filePath string + mux sync.RWMutex + closed bool } -func NewSqliteStore(ctx context.Context, dataDir string) (*SqliteStore, error) { - file := filepath.Join(dataDir, GeoSqliteDBFile) +func NewSqliteStore(ctx context.Context, dataDir string, dbPath string) (*SqliteStore, error) { + file := filepath.Join(dataDir, dbPath) db, err := connectDB(ctx, file) if err != nil { return nil, err } - sha256sum, err := calculateFileSHA256(file) - if err != nil { - return nil, err - } - return &SqliteStore{ - db: db, - filePath: file, - mux: sync.RWMutex{}, - sha256sum: sha256sum, + db: db, + filePath: file, + mux: sync.RWMutex{}, }, nil } @@ -115,48 +103,6 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error) return cities, nil } -// reload attempts to reload the SqliteStore's database if the database file has changed. -func (s *SqliteStore) reload(ctx context.Context) error { - s.mux.Lock() - defer s.mux.Unlock() - - newSha256sum1, err := calculateFileSHA256(s.filePath) - if err != nil { - log.WithContext(ctx).Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) - } - - if !bytes.Equal(s.sha256sum, newSha256sum1) { - // we check sum twice just to avoid possible case when we reload during update of the file - // considering the frequency of file update (few times a week) checking sum twice should be enough - time.Sleep(50 * time.Millisecond) - newSha256sum2, err := calculateFileSHA256(s.filePath) - if err != nil { - return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) - } - if !bytes.Equal(newSha256sum1, newSha256sum2) { - return fmt.Errorf("sha256 sum changed during reloading of '%s'", s.filePath) - } - - log.WithContext(ctx).Infof("Reloading '%s'", s.filePath) - _ = s.close() - s.closed = true - - newDb, err := connectDB(ctx, s.filePath) - if err != nil { - return err - } - - s.closed = false - s.db = newDb - - log.WithContext(ctx).Infof("Successfully reloaded '%s'", s.filePath) - } else { - log.WithContext(ctx).Tracef("No changes in '%s', no need to reload", s.filePath) - } - - return nil -} - // close closes the database connection. // It retrieves the underlying *sql.DB object from the *gorm.DB object // and calls the Close() method on it. diff --git a/management/server/geolocation/utils.go b/management/server/geolocation/utils.go index bdbd4732d..5104b0a08 100644 --- a/management/server/geolocation/utils.go +++ b/management/server/geolocation/utils.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "mime" "net/http" "os" "path" @@ -174,3 +175,21 @@ func downloadFile(url, filepath string) error { _, err = io.Copy(out, bytes.NewBuffer(bodyBytes)) return err } + +func getFilenameFromURL(url string) (string, error) { + resp, err := http.Head(url) + if err != nil { + return "", err + } + + defer resp.Body.Close() + + _, params, err := mime.ParseMediaType(resp.Header["Content-Disposition"][0]) + if err != nil { + return "", err + } + + filename := params["filename"] + + return filename, nil +} diff --git a/management/server/group.go b/management/server/group.go index 63281a2f1..91c06a3c0 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -25,91 +25,46 @@ func (e *GroupLinkError) Error() string { return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name) } -// GetGroup object of the peers +// CheckGroupPermissions validates if a user has the necessary permissions to view groups +func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, "groups are blocked for users") + } + + return nil +} + +// GetGroup returns a specific group by groupID in an account func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") - } - - group, ok := account.Groups[groupID] - if ok { - return group, nil - } - - return nil, status.Errorf(status.NotFound, "group with ID %s not found", groupID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { - return nil, status.Errorf(status.PermissionDenied, "groups are blocked for users") - } - - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil + return am.Store.GetAccountGroups(ctx, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - matchingGroups := make([]*nbgroup.Group, 0) - for _, group := range account.Groups { - if group.Name == groupName { - matchingGroups = append(matchingGroups, group) - } - } - - if len(matchingGroups) == 0 { - return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName) - } - - maxPeers := -1 - var groupWithMostPeers *nbgroup.Group - for i, group := range matchingGroups { - if len(group.Peers) > maxPeers { - maxPeers = len(group.Peers) - groupWithMostPeers = matchingGroups[i] - } - } - - return groupWithMostPeers, nil + return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) } // SaveGroup object of the peers @@ -269,6 +224,15 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use return nil } + allGroup, err := account.GetGroupAll() + if err != nil { + return err + } + + if allGroup.ID == groupID { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if err = validateDeleteGroup(account, group, userId); err != nil { return err } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index ff7a71cfd..4c4ef6c3c 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -16,13 +16,12 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - nbContext "github.com/netbirdio/netbird/management/server/context" - "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/management/proto" + nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -32,17 +31,25 @@ type GRPCServer struct { accountManager AccountManager wgKey wgtypes.Key proto.UnimplementedManagementServiceServer - peersUpdateManager *PeersUpdateManager - config *Config - turnCredentialsManager TURNCredentialsManager - jwtValidator *jwtclaims.JWTValidator - jwtClaimsExtractor *jwtclaims.ClaimsExtractor - appMetrics telemetry.AppMetrics - ephemeralManager *EphemeralManager + peersUpdateManager *PeersUpdateManager + config *Config + secretsManager SecretsManager + jwtValidator *jwtclaims.JWTValidator + jwtClaimsExtractor *jwtclaims.ClaimsExtractor + appMetrics telemetry.AppMetrics + ephemeralManager *EphemeralManager } // NewServer creates a new Management server -func NewServer(ctx context.Context, config *Config, accountManager AccountManager, peersUpdateManager *PeersUpdateManager, turnCredentialsManager TURNCredentialsManager, appMetrics telemetry.AppMetrics, ephemeralManager *EphemeralManager) (*GRPCServer, error) { +func NewServer( + ctx context.Context, + config *Config, + accountManager AccountManager, + peersUpdateManager *PeersUpdateManager, + secretsManager SecretsManager, + appMetrics telemetry.AppMetrics, + ephemeralManager *EphemeralManager, +) (*GRPCServer, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { return nil, err @@ -88,14 +95,14 @@ func NewServer(ctx context.Context, config *Config, accountManager AccountManage return &GRPCServer{ wgKey: key, // peerKey -> event channel - peersUpdateManager: peersUpdateManager, - accountManager: accountManager, - config: config, - turnCredentialsManager: turnCredentialsManager, - jwtValidator: jwtValidator, - jwtClaimsExtractor: jwtClaimsExtractor, - appMetrics: appMetrics, - ephemeralManager: ephemeralManager, + peersUpdateManager: peersUpdateManager, + accountManager: accountManager, + config: config, + secretsManager: secretsManager, + jwtValidator: jwtValidator, + jwtClaimsExtractor: jwtClaimsExtractor, + appMetrics: appMetrics, + ephemeralManager: ephemeralManager, }, nil } @@ -132,24 +139,30 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi ctx := srv.Context() - realIP := getRealIP(ctx) - syncReq := &proto.SyncRequest{} peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { return err } - //nolint + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) + accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { - // this case should not happen and already indicates an issue but we don't want the system to fail due to being unable to log in detail - accountID = "UNKNOWN" + // nolint:staticcheck + ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN") + log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String()) + if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + return status.Errorf(codes.PermissionDenied, "peer is not registered") + } + return err } - //nolint + + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) if syncReq.GetMeta() == nil { @@ -171,9 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.ephemeralManager.OnPeerConnected(ctx, peer) - if s.config.TURNConfig.TimeBasedCredentials { - s.turnCredentialsManager.SetupRefresh(ctx, peer.ID) - } + s.secretsManager.SetupRefresh(ctx, peer.ID) if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart)) @@ -235,7 +246,7 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(ctx, peer.ID) - s.turnCredentialsManager.CancelRefresh(peer.ID) + s.secretsManager.CancelRefresh(peer.ID) _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) s.ephemeralManager.OnPeerDisconnected(ctx, peer) } @@ -251,7 +262,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string } claims := s.jwtClaimsExtractor.FromToken(token) // we need to call this method because if user is new, we will automatically add it to existing or create a new account - _, _, err = s.accountManager.GetAccountFromToken(ctx, claims) + _, _, err = s.accountManager.GetAccountIDFromToken(ctx, claims) if err != nil { return "", status.Errorf(codes.Internal, "unable to fetch account with claims, err: %v", err) } @@ -421,9 +432,17 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.ephemeralManager.OnPeerDisconnected(ctx, peer) } + var relayToken *Token + if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { + relayToken, err = s.secretsManager.GenerateRelayToken() + if err != nil { + log.Errorf("failed generating Relay token: %v", err) + } + } + // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ - WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), + WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), Checks: toProtocolChecks(ctx, postureChecks), } @@ -481,10 +500,11 @@ func ToResponseProto(configProto Protocol) proto.HostConfig_Protocol { } } -func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *proto.WiretrusteeConfig { +func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Token) *proto.WiretrusteeConfig { if config == nil { return nil } + var stuns []*proto.HostConfig for _, stun := range config.Stuns { stuns = append(stuns, &proto.HostConfig{ @@ -492,25 +512,40 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot Protocol: ToResponseProto(stun.Proto), }) } + var turns []*proto.ProtectedHostConfig - for _, turn := range config.TURNConfig.Turns { - var username string - var password string - if turnCredentials != nil { - username = turnCredentials.Username - password = turnCredentials.Password - } else { - username = turn.Username - password = turn.Password + if config.TURNConfig != nil { + for _, turn := range config.TURNConfig.Turns { + var username string + var password string + if turnCredentials != nil { + username = turnCredentials.Payload + password = turnCredentials.Signature + } else { + username = turn.Username + password = turn.Password + } + turns = append(turns, &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: turn.URI, + Protocol: ToResponseProto(turn.Proto), + }, + User: username, + Password: password, + }) + } + } + + var relayCfg *proto.RelayConfig + if config.Relay != nil && len(config.Relay.Addresses) > 0 { + relayCfg = &proto.RelayConfig{ + Urls: config.Relay.Addresses, + } + + if relayToken != nil { + relayCfg.TokenPayload = relayToken.Payload + relayCfg.TokenSignature = relayToken.Signature } - turns = append(turns, &proto.ProtectedHostConfig{ - HostConfig: &proto.HostConfig{ - Uri: turn.URI, - Protocol: ToResponseProto(turn.Proto), - }, - User: username, - Password: password, - }) } return &proto.WiretrusteeConfig{ @@ -520,6 +555,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot Uri: config.Signal.URI, Protocol: ToResponseProto(config.Signal.Proto), }, + Relay: relayCfg, } } @@ -533,9 +569,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe } } -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { response := &proto.SyncResponse{ - WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials), + WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), @@ -560,6 +596,10 @@ func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turn response.NetworkMap.FirewallRules = firewallRules response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0 + routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules) + response.NetworkMap.RoutesFirewallRules = routesFirewallRules + response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0 + return response } @@ -582,15 +622,25 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { - // make secret time based TURN credentials optional - var turnCredentials *TURNCredentials - if s.config.TURNConfig.TimeBasedCredentials { - creds := s.turnCredentialsManager.GenerateCredentials() - turnCredentials = &creds - } else { - turnCredentials = nil + var err error + + var turnToken *Token + if s.config.TURNConfig != nil && s.config.TURNConfig.TimeBasedCredentials { + turnToken, err = s.secretsManager.GenerateTurnToken() + if err != nil { + log.Errorf("failed generating TURN token: %v", err) + } } - plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil) + + var relayToken *Token + if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 { + relayToken, err = s.secretsManager.GenerateRelayToken() + if err != nil { + log.Errorf("failed generating Relay token: %v", err) + } + } + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index ffa5b9a28..91caa1512 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -35,25 +35,26 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - if !(user.HasAdminPower() || user.IsServiceUser) { - util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) + settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(account) + resp := toAccountResponse(accountID, settings) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,24 +97,19 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(updatedAccount) + resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) util.WriteJSONObject(r.Context(), w, &resp) } // DeleteAccount is a HTTP DELETE handler to delete an account func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -131,28 +127,28 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, emptyObject{}) } -func toAccountResponse(account *server.Account) *api.Account { - jwtAllowGroups := account.Settings.JWTAllowGroups +func toAccountResponse(accountID string, settings *server.Settings) *api.Account { + jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} } - settings := api.AccountSettings{ - PeerLoginExpiration: int(account.Settings.PeerLoginExpiration.Seconds()), - PeerLoginExpirationEnabled: account.Settings.PeerLoginExpirationEnabled, - GroupsPropagationEnabled: &account.Settings.GroupsPropagationEnabled, - JwtGroupsEnabled: &account.Settings.JWTGroupsEnabled, - JwtGroupsClaimName: &account.Settings.JWTGroupsClaimName, + apiSettings := api.AccountSettings{ + PeerLoginExpiration: int(settings.PeerLoginExpiration.Seconds()), + PeerLoginExpirationEnabled: settings.PeerLoginExpirationEnabled, + GroupsPropagationEnabled: &settings.GroupsPropagationEnabled, + JwtGroupsEnabled: &settings.JWTGroupsEnabled, + JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, - RegularUsersViewBlocked: account.Settings.RegularUsersViewBlocked, + RegularUsersViewBlocked: settings.RegularUsersViewBlocked, } - if account.Settings.Extra != nil { - settings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &account.Settings.Extra.PeerApprovalEnabled} + if settings.Extra != nil { + apiSettings.Extra = &api.AccountExtraSettings{PeerApprovalEnabled: &settings.Extra.PeerApprovalEnabled} } return &api.Account{ - Id: account.Id, - Settings: settings, + Id: accountID, + Settings: apiSettings, } } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 45c7679e5..cacb3d430 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -23,8 +23,11 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return account, admin, nil + GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return account.Id, admin.Id, nil + }, + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + return account.Settings, nil }, UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 45887dc2e..fd0343e97 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -251,7 +251,7 @@ components: - name - ssh_enabled - login_expiration_enabled - PeerBase: + Peer: allOf: - $ref: '#/components/schemas/PeerMinimum' - type: object @@ -378,25 +378,40 @@ components: description: User ID of the user that enrolled this peer type: string example: google-oauth2|277474792786460067937 + os: + description: Peer's operating system and version + type: string + example: linux + country_code: + $ref: '#/components/schemas/CountryCode' + city_name: + $ref: '#/components/schemas/CityName' + geoname_id: + description: Unique identifier from the GeoNames database for a specific geographical location. + type: integer + example: 2643743 + connected: + description: Peer to Management connection status + type: boolean + example: true + last_seen: + description: Last time peer connected to Netbird's management service + type: string + format: date-time + example: "2023-05-05T10:05:26.420578Z" required: - ip - dns_label - user_id - Peer: - allOf: - - $ref: '#/components/schemas/PeerBase' - - type: object - properties: - accessible_peers: - description: List of accessible peers - type: array - items: - $ref: '#/components/schemas/AccessiblePeer' - required: - - accessible_peers + - os + - country_code + - city_name + - geoname_id + - connected + - last_seen PeerBatch: allOf: - - $ref: '#/components/schemas/PeerBase' + - $ref: '#/components/schemas/Peer' - type: object properties: accessible_peers_count: @@ -712,17 +727,39 @@ components: enum: ["all", "tcp", "udp", "icmp"] example: "tcp" ports: - description: Policy rule affected ports or it ranges list + description: Policy rule affected ports type: array items: type: string example: "80" + port_ranges: + description: Policy rule affected ports ranges list + type: array + items: + $ref: '#/components/schemas/RulePortRange' required: - name - enabled - bidirectional - protocol - action + + RulePortRange: + description: Policy rule affected ports range + type: object + properties: + start: + description: The starting port of the range + type: integer + example: 80 + end: + description: The ending port of the range + type: integer + example: 320 + required: + - start + - end + PolicyRuleUpdate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -935,7 +972,7 @@ components: type: array items: type: string - example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"] + example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"] action: description: Action to take upon policy match type: string @@ -1064,12 +1101,12 @@ components: type: string example: 10.64.0.0/24 domains: - description: Domain list to be dynamically resolved. Conflicts with network + description: Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network type: array items: type: string minLength: 1 - maxLength: 255 + maxLength: 32 example: "example.com" metric: description: Route metric number. Lowest number has higher priority @@ -1091,6 +1128,12 @@ components: description: Indicate if the route should be kept after a domain doesn't resolve that IP anymore type: boolean example: true + access_control_groups: + description: Access control group identifier associated with route. + type: array + items: + type: string + example: "chacbco6lnnbn6cg5s91" required: - id - description @@ -1806,6 +1849,38 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/peers/{peerId}/accessible-peers: + get: + summary: List accessible Peers + description: Returns a list of peers that the specified peer can connect to within the network. + tags: [ Peers ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: peerId + required: true + schema: + type: string + description: The unique identifier of a peer + responses: + '200': + description: A JSON Array of Accessible Peers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/AccessiblePeer' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/setup-keys: get: summary: List all Setup Keys @@ -2759,4 +2834,4 @@ paths: '403': "$ref": "#/components/responses/forbidden" '500': - "$ref": "#/components/responses/internal_error" \ No newline at end of file + "$ref": "#/components/responses/internal_error" diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 77a6c643d..570ec03c5 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -152,18 +152,36 @@ const ( // AccessiblePeer defines model for AccessiblePeer. type AccessiblePeer struct { + // CityName Commonly used English name of the city + CityName CityName `json:"city_name"` + + // Connected Peer to Management connection status + Connected bool `json:"connected"` + + // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country + CountryCode CountryCode `json:"country_code"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` + // GeonameId Unique identifier from the GeoNames database for a specific geographical location. + GeonameId int `json:"geoname_id"` + // Id Peer ID Id string `json:"id"` // Ip Peer's IP address Ip string `json:"ip"` + // LastSeen Last time peer connected to Netbird's management service + LastSeen time.Time `json:"last_seen"` + // Name Peer's hostname Name string `json:"name"` + // Os Peer's operating system and version + Os string `json:"os"` + // UserId User ID of the user that enrolled this peer UserId string `json:"user_id"` } @@ -490,81 +508,6 @@ type OSVersionCheck struct { // Peer defines model for Peer. type Peer struct { - // AccessiblePeers List of accessible peers - AccessiblePeers []AccessiblePeer `json:"accessible_peers"` - - // ApprovalRequired (Cloud only) Indicates whether peer needs approval - ApprovalRequired bool `json:"approval_required"` - - // CityName Commonly used English name of the city - CityName CityName `json:"city_name"` - - // Connected Peer to Management connection status - Connected bool `json:"connected"` - - // ConnectionIp Peer's public connection IP address - ConnectionIp string `json:"connection_ip"` - - // CountryCode 2-letter ISO 3166-1 alpha-2 code that represents the country - CountryCode CountryCode `json:"country_code"` - - // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud - DnsLabel string `json:"dns_label"` - - // GeonameId Unique identifier from the GeoNames database for a specific geographical location. - GeonameId int `json:"geoname_id"` - - // Groups Groups that the peer belongs to - Groups []GroupMinimum `json:"groups"` - - // Hostname Hostname of the machine - Hostname string `json:"hostname"` - - // Id Peer ID - Id string `json:"id"` - - // Ip Peer's IP address - Ip string `json:"ip"` - - // KernelVersion Peer's operating system kernel version - KernelVersion string `json:"kernel_version"` - - // LastLogin Last time this peer performed log in (authentication). E.g., user authenticated. - LastLogin time.Time `json:"last_login"` - - // LastSeen Last time peer connected to Netbird's management service - LastSeen time.Time `json:"last_seen"` - - // LoginExpirationEnabled Indicates whether peer login expiration has been enabled or not - LoginExpirationEnabled bool `json:"login_expiration_enabled"` - - // LoginExpired Indicates whether peer's login expired or not - LoginExpired bool `json:"login_expired"` - - // Name Peer's hostname - Name string `json:"name"` - - // Os Peer's operating system and version - Os string `json:"os"` - - // SerialNumber System serial number - SerialNumber string `json:"serial_number"` - - // SshEnabled Indicates whether SSH server is enabled on this peer - SshEnabled bool `json:"ssh_enabled"` - - // UiVersion Peer's desktop UI version - UiVersion string `json:"ui_version"` - - // UserId User ID of the user that enrolled this peer - UserId string `json:"user_id"` - - // Version Peer's daemon or cli version - Version string `json:"version"` -} - -// PeerBase defines model for PeerBase. -type PeerBase struct { // ApprovalRequired (Cloud only) Indicates whether peer needs approval ApprovalRequired bool `json:"approval_required"` @@ -837,7 +780,10 @@ type PolicyRule struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -873,7 +819,10 @@ type PolicyRuleMinimum struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -909,7 +858,10 @@ type PolicyRuleUpdate struct { // Name Policy rule name identifier Name string `json:"name"` - // Ports Policy rule affected ports or it ranges list + // PortRanges Policy rule affected ports ranges list + PortRanges *[]RulePortRange `json:"port_ranges,omitempty"` + + // Ports Policy rule affected ports Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic @@ -992,10 +944,13 @@ type ProcessCheck struct { // Route defines model for Route. type Route struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` - // Domains Domain list to be dynamically resolved. Conflicts with network + // Domains Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network Domains *[]string `json:"domains,omitempty"` // Enabled Route status @@ -1034,10 +989,13 @@ type Route struct { // RouteRequest defines model for RouteRequest. type RouteRequest struct { + // AccessControlGroups Access control group identifier associated with route. + AccessControlGroups *[]string `json:"access_control_groups,omitempty"` + // Description Route description Description string `json:"description"` - // Domains Domain list to be dynamically resolved. Conflicts with network + // Domains Domain list to be dynamically resolved. Max of 32 domains can be added per route configuration. Conflicts with network Domains *[]string `json:"domains,omitempty"` // Enabled Route status @@ -1068,6 +1026,15 @@ type RouteRequest struct { PeerGroups *[]string `json:"peer_groups,omitempty"` } +// RulePortRange Policy rule affected ports range +type RulePortRange struct { + // End The ending port of the range + End int `json:"end"` + + // Start The starting port of the range + Start int `json:"start"` +} + // SetupKey defines model for SetupKey. type SetupKey struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index 74b0e1a55..13c2101a7 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index 897ae63dc..8baea7b15 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil + GetAccountIDFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingDNSSettingsAccount.Id, testingDNSSettingsAccount.Users[testDNSSettingsUserID].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 428b4c164..ee0c63f28 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 8bdd508bf..e525cf2ee 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -20,7 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, user *server.User, events ...*activity.Event) *EventsHandler { +func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { return &EventsHandler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { @@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil @@ -199,7 +193,7 @@ func TestEvents_GetEvents(t *testing.T) { accountID := "test_account" adminUser := server.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) - handler := initEventsTestData(accountID, adminUser, events...) + handler := initEventsTestData(accountID, events...) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index b8247f78d..19c916dd2 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -7,12 +7,13 @@ import ( "net/http" "net/http/httptest" "path" + "path/filepath" "testing" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -24,32 +25,29 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { t.Helper() var ( - mmdbPath = "../testdata/GeoLite2-City-Test.mmdb" - geonamesDBPath = "../testdata/geonames-test.db" + mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" + geonamesdbPath = "../testdata/geonames_20240305.db" ) tempDir := t.TempDir() - err := util.CopyFileContents(mmdbPath, path.Join(tempDir, geolocation.MMDBFileName)) + err := util.CopyFileContents(mmdbPath, path.Join(tempDir, filepath.Base(mmdbPath))) assert.NoError(t, err) - err = util.CopyFileContents(geonamesDBPath, path.Join(tempDir, geolocation.GeoSqliteDBFile)) + err = util.CopyFileContents(geonamesdbPath, path.Join(tempDir, filepath.Base(geonamesdbPath))) assert.NoError(t, err) - geo, err := geolocation.NewGeolocation(context.Background(), tempDir) + geo, err := geolocation.NewGeolocation(context.Background(), tempDir, false) assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return server.NewAdminUser(id), nil }, }, geolocationManager: geo, diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index af4d3116f..418228abf 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -98,7 +98,12 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + return err + } + + user, err := l.accountManager.GetUserByID(r.Context(), userID) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c622d873a..f369d1a00 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gorilla/mux" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" @@ -35,14 +36,20 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) + groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -50,7 +57,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { groupsResponse := make([]*api.Group, 0, len(groups)) for _, group := range groups { - groupsResponse = append(groupsResponse, toGroupResponse(account, group)) + groupsResponse = append(groupsResponse, toGroupResponse(accountPeers, group)) } util.WriteJSONObject(r.Context(), w, groupsResponse) @@ -59,7 +66,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,17 +83,18 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { return } - eg, ok := account.Groups[groupID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find group with ID %s", groupID), w) - return - } - - allGroup, err := account.GetGroupAll() + existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } + + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + if allGroup.ID == groupID { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "updating group ALL is not allowed"), w) return @@ -114,23 +122,29 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { ID: groupID, Name: req.Name, Peers: peers, - Issued: eg.Issued, - IntegrationReference: eg.IntegrationReference, + Issued: existingGroup.Issued, + IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { - log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) + if err := h.accountManager.SaveGroup(r.Context(), accountID, userID, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,24 +174,29 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) if err != nil { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, &group)) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { @@ -185,18 +204,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - allGroup, err := account.GetGroupAll() - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if allGroup.ID == groupID { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed"), w) - return - } - - err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { @@ -213,34 +221,39 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupID := mux.Vars(r)["groupId"] + if len(groupID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) + return + } + + group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - groupID := mux.Vars(r)["groupId"] - if len(groupID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) - return - } - - group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - util.WriteJSONObject(r.Context(), w, toGroupResponse(account, group)) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "HTTP method not found"), w) + accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } + + util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) + } -func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + cache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, @@ -251,7 +264,7 @@ func toGroupResponse(account *server.Account, group *nbgroup.Group) *api.Group { for _, pid := range group.Peers { _, ok := cache[pid] if !ok { - peer, ok := account.Peers[pid] + peer, ok := peersMap[pid] if !ok { continue } diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index d5ed07c9e..7f3c81f18 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/mux" "github.com/magiconair/properties/assert" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" @@ -30,7 +31,7 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { +func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { return &GroupsHandler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { @@ -40,36 +41,35 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { return nil }, GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - if groupID != "idofthegroup" { + groups := map[string]*nbgroup.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + } + + for _, group := range initGroups { + groups[group.ID] = group + } + + group, ok := groups[groupID] + if !ok { return nil, status.Errorf(status.NotFound, "not found") } - if groupID == "id-jwt-group" { - return &nbgroup.Group{ - ID: "id-jwt-group", - Name: "Default Group", - Issued: nbgroup.GroupIssuedJWT, - }, nil - } - return &nbgroup.Group{ - ID: "idofthegroup", - Name: "Group", - Issued: nbgroup.GroupIssuedAPI, - }, nil + + return group, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Peers: TestPeers, - Users: map[string]*server.User{ - user.Id: user, - }, - Groups: map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + if groupName == "All" { + return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + } + + return nil, fmt.Errorf("unknown group name") + }, + GetPeersFunc: func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + return maps.Values(TestPeers), nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { @@ -125,8 +125,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser, group) + p := initGroupTestData(group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -247,8 +246,7 @@ func TestWriteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -325,8 +323,7 @@ func TestDeleteGroup(t *testing.T) { }, } - adminUser := server.NewAdminUser("test_user") - p := initGroupTestData(adminUser) + p := initGroupTestData() for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3fe26d0ce..3f8a8554d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -82,7 +82,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa AuthCfg: authCfg, } - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } @@ -100,27 +100,6 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa api.addPostureCheckEndpoint() api.addLocationsEndpoint() - err := api.Router.Walk(func(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { - methods, err := route.GetMethods() - if err != nil { // we may have wildcard routes from integrations without methods, skip them for now - methods = []string{} - } - for _, method := range methods { - template, err := route.GetPathTemplate() - if err != nil { - return err - } - err = metricsMiddleware.AddHTTPRequestResponseCounter(template, method) - if err != nil { - return err - } - } - return nil - }) - if err != nil { - return nil, err - } - return rootRouter, nil } @@ -136,6 +115,7 @@ func (apiHandler *apiHandler) addPeersEndpoint() { apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") + apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") } func (apiHandler *apiHandler) addUsersEndpoint() { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index c6e00bb2d..e7a2bc2ae 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 28b080571..98c2e402d 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -18,7 +18,6 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -29,14 +28,6 @@ const ( testNSGroupAccountID = "test_id" ) -var testingNSAccount = &server.Account{ - Id: testNSGroupAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), - }, -} - var baseExistingNSGroup = &nbdns.NameServerGroup{ ID: existingNSGroupID, Name: "super", @@ -90,8 +81,8 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 9d8448d3d..dfa9563e3 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] + targetUserID := vars["userId"] if len(userID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index b72f71468..c28228a50 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -77,8 +77,8 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testAccount, testAccount.Users[existingUserID], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { @@ -119,7 +119,7 @@ func initPATTestData() *PATHandler { return jwtclaims.AuthorizationClaims{ UserId: existingUserID, Domain: testDomain, - AccountId: testNSGroupAccountID, + AccountId: existingAccountID, } }), ), diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 913d424d1..4fbbc3106 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -7,8 +7,6 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -16,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) // PeersHandler is a handler that returns peers of the account @@ -71,15 +70,11 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee return } - customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain()) - netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil) - accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - _, valid := validPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, accessiblePeers, valid)) + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -101,7 +96,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) if err != nil { util.WriteError(ctx, err, w) return @@ -117,13 +112,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, return } - customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain()) - netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil) - accessiblePeers := toAccessiblePeers(netMap, dnsDomain) - _, valid := validPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, accessiblePeers, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) } func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -139,7 +130,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -153,13 +144,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodDelete: - h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodPut: - h.updatePeer(r.Context(), account, user, peerID, w, r) - return - case http.MethodGet: - h.getPeer(r.Context(), account, peerID, user.Id, w) + case http.MethodGet, http.MethodPut: + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + if r.Method == http.MethodGet { + h.getPeer(r.Context(), account, peerID, userID, w) + } else { + h.updatePeer(r.Context(), account, userID, peerID, w, r) + } return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -168,19 +166,14 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { // GetAllPeers returns a list of all peers associated with a provided account func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) - return - } - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -188,8 +181,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(peers)) - for _, peer := range peers { + respBody := make([]*api.PeerBatch, 0, len(account.Peers)) + for _, peer := range account.Peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) @@ -220,32 +213,93 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv } } +// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. +func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + user, err := account.FindUser(userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + // If the user is regular user and does not own the peer + // with the given peerID return an empty list + if !user.HasAdminPower() && !user.IsServiceUser { + peer, ok := account.Peers[peerID] + if !ok { + util.WriteError(r.Context(), status.Errorf(status.NotFound, "peer not found"), w) + return + } + + if peer.UserID != user.Id { + util.WriteJSONObject(r.Context(), w, []api.AccessiblePeer{}) + return + } + } + + dnsDomain := h.accountManager.GetDNSDomain() + + validPeers, err := h.accountManager.GetValidatedPeers(account) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) + return + } + + customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) + + util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) +} + func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { - ap := api.AccessiblePeer{ - Id: p.ID, - Name: p.Name, - Ip: p.IP.String(), - DnsLabel: fqdn(p, dnsDomain), - UserId: p.UserID, - } - accessiblePeers = append(accessiblePeers, ap) + accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain)) } for _, p := range netMap.OfflinePeers { - ap := api.AccessiblePeer{ - Id: p.ID, - Name: p.Name, - Ip: p.IP.String(), - DnsLabel: fqdn(p, dnsDomain), - UserId: p.UserID, - } - accessiblePeers = append(accessiblePeers, ap) + accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain)) } + return accessiblePeers } +func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePeer { + return api.AccessiblePeer{ + CityName: peer.Location.CityName, + Connected: peer.Status.Connected, + CountryCode: peer.Location.CountryCode, + DnsLabel: fqdn(peer, dnsDomain), + GeonameId: int(peer.Location.GeoNameID), + Id: peer.ID, + Ip: peer.IP.String(), + LastSeen: peer.Status.LastSeen, + Name: peer.Name, + Os: peer.Meta.OS, + UserId: peer.UserID, + } +} + func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { var groupsInfo []api.GroupMinimum groupsChecked := make(map[string]struct{}) @@ -270,7 +324,7 @@ func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMi return groupsInfo } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeer []api.AccessiblePeer, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core @@ -296,7 +350,6 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD LoginExpirationEnabled: peer.LoginExpirationEnabled, LastLogin: peer.LastLogin, LoginExpired: peer.Status.LoginExpired, - AccessiblePeers: accessiblePeer, ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, CityName: peer.Location.CityName, diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 153c8f03a..f933eee14 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "net" "net/http" @@ -12,20 +13,29 @@ import ( "time" "github.com/gorilla/mux" - - "github.com/netbirdio/netbird/management/server/http/api" - nbpeer "github.com/netbirdio/netbird/management/server/peer" - - "github.com/netbirdio/netbird/management/server/jwtclaims" - - "github.com/magiconair/properties/assert" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/jwtclaims" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "golang.org/x/exp/maps" + + "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/server/mock_server" ) -const testPeerID = "test_peer" -const noUpdateChannelTestPeerID = "no-update-channel" +type ctxKey string + +const ( + testPeerID = "test_peer" + noUpdateChannelTestPeerID = "no-update-channel" + + adminUser = "admin_user" + regularUser = "regular_user" + serviceUser = "service_user" + userIDKey ctxKey = "user_id" +) func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { return &PeersHandler{ @@ -59,22 +69,61 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Peers: map[string]*nbpeer.Peer{ - peers[0].ID: peers[0], - peers[1].ID: peers[1], + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer.Copy() + } + + policy := &server.Policy{ + ID: "policy", + AccountID: accountID, + Name: "policy", + Enabled: true, + Rules: []*server.PolicyRule{ + { + ID: "rule", + Name: "rule", + Enabled: true, + Action: "accept", + Destinations: []string{"group1"}, + Sources: []string{"group1"}, + Bidirectional: true, + Protocol: "all", + Ports: []string{"80"}, + }, }, + } + + srvUser := server.NewRegularUser(serviceUser) + srvUser.IsServiceUser = true + + account := &server.Account{ + Id: accountID, + Domain: "hotmail.com", + Peers: peersMap, Users: map[string]*server.User{ - "test_user": user, + adminUser: server.NewAdminUser(adminUser), + regularUser: server.NewRegularUser(regularUser), + serviceUser: srvUser, + }, + Groups: map[string]*nbgroup.Group{ + "group1": { + ID: "group1", + AccountID: accountID, + Name: "group1", + Issued: "api", + Peers: maps.Keys(peersMap), + }, }, Settings: &server.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, + Policies: []*server.Policy{policy}, Network: &server.Network{ Identifier: "ciclqisab2ss43jdn8q0", Net: net.IPNet{ @@ -83,7 +132,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, Serial: 51, }, - }, user, nil + } + + return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { statuses := make(map[string]struct{}) @@ -99,8 +150,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { + userID := r.Context().Value(userIDKey).(string) return jwtclaims.AuthorizationClaims{ - UserId: "test_user", + UserId: userID, Domain: "hotmail.com", AccountId: "test_id", } @@ -197,6 +249,8 @@ func TestGetPeers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) + ctx := context.WithValue(context.Background(), userIDKey, "admin_user") + req = req.WithContext(ctx) router := mux.NewRouter() router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") @@ -227,9 +281,15 @@ func TestGetPeers(t *testing.T) { // hardcode this check for now as we only have two peers in this suite assert.Equal(t, len(respBody), 2) - assert.Equal(t, respBody[1].Connected, false) - got = respBody[0] + for _, peer := range respBody { + if peer.Id == testPeerID { + got = peer + } else { + assert.Equal(t, peer.Connected, false) + } + } + } else { got = &api.Peer{} err = json.Unmarshal(content, got) @@ -251,3 +311,119 @@ func TestGetPeers(t *testing.T) { }) } } + +func TestGetAccessiblePeers(t *testing.T) { + peer1 := &nbpeer.Peer{ + ID: "peer1", + Key: "key1", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{Connected: true}, + Name: "peer1", + LoginExpirationEnabled: false, + UserID: regularUser, + } + + peer2 := &nbpeer.Peer{ + ID: "peer2", + Key: "key2", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{Connected: true}, + Name: "peer2", + LoginExpirationEnabled: false, + UserID: adminUser, + } + + peer3 := &nbpeer.Peer{ + ID: "peer3", + Key: "key3", + IP: net.ParseIP("100.64.0.3"), + Status: &nbpeer.PeerStatus{Connected: true}, + Name: "peer3", + LoginExpirationEnabled: false, + UserID: regularUser, + } + + p := initTestMetaData(peer1, peer2, peer3) + + tt := []struct { + name string + peerID string + callerUserID string + expectedStatus int + expectedPeers []string + }{ + { + name: "non admin user can access owned peer", + peerID: "peer1", + callerUserID: regularUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer2", "peer3"}, + }, + { + name: "non admin user can't access unowned peer", + peerID: "peer2", + callerUserID: regularUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{}, + }, + { + name: "admin user can access owned peer", + peerID: "peer2", + callerUserID: adminUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer1", "peer3"}, + }, + { + name: "admin user can access unowned peer", + peerID: "peer3", + callerUserID: adminUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer1", "peer2"}, + }, + { + name: "service user can access unowned peer", + peerID: "peer3", + callerUserID: serviceUser, + expectedStatus: http.StatusOK, + expectedPeers: []string{"peer1", "peer2"}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + + recorder := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/api/peers/%s/accessible-peers", tc.peerID), nil) + ctx := context.WithValue(context.Background(), userIDKey, tc.callerUserID) + req = req.WithContext(ctx) + + router := mux.NewRouter() + router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") + router.ServeHTTP(recorder, req) + + res := recorder.Result() + if res.StatusCode != tc.expectedStatus { + t.Fatalf("handler returned wrong status code: got %v want %v", res.StatusCode, tc.expectedStatus) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + defer res.Body.Close() + + var accessiblePeers []api.AccessiblePeer + err = json.Unmarshal(body, &accessiblePeers) + if err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + peerIDs := make([]string, len(accessiblePeers)) + for i, peer := range accessiblePeers { + peerIDs[i] = peer.Id + } + + assert.ElementsMatch(t, peerIDs, tc.expectedPeers) + }) + } +} diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 9622668f4..73f3803b5 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/gorilla/mux" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" "github.com/netbirdio/netbird/management/server" @@ -35,21 +36,27 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) + listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - policies := []*api.Policy{} - for _, policy := range accountPolicies { - resp := toPolicyResponse(account, policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policies := make([]*api.Policy, 0, len(listPolicies)) + for _, policy := range listPolicies { + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -63,7 +70,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,41 +83,29 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) - return - } - - h.savePolicy(w, r, account, user, policyID) -} - -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { - claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, account, user, "") + h.savePolicy(w, r, accountID, userID, policyID) +} + +// CreatePolicy handles policy creation request +func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + h.savePolicy(w, r, accountID, userID, "") } // savePolicy handles policy creation and update -func (h *Policies) savePolicy( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - policyID string, -) { +func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -127,6 +122,8 @@ func (h *Policies) savePolicy( return } + isUpdate := policyID != "" + if policyID == "" { policyID = xid.New().String() } @@ -141,8 +138,8 @@ func (h *Policies) savePolicy( pr := server.PolicyRule{ ID: policyID, // TODO: when policy can contain multiple rules, need refactor Name: rule.Name, - Destinations: groupMinimumsToStrings(account, rule.Destinations), - Sources: groupMinimumsToStrings(account, rule.Sources), + Destinations: rule.Destinations, + Sources: rule.Sources, Bidirectional: rule.Bidirectional, } @@ -175,6 +172,11 @@ func (h *Policies) savePolicy( return } + if (rule.Ports != nil && len(*rule.Ports) != 0) && (rule.PortRanges != nil && len(*rule.PortRanges) != 0) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either individual ports or port ranges, not both"), w) + return + } + if rule.Ports != nil && len(*rule.Ports) != 0 { for _, v := range *rule.Ports { if port, err := strconv.Atoi(v); err != nil || port < 1 || port > 65535 { @@ -185,10 +187,23 @@ func (h *Policies) savePolicy( } } + if rule.PortRanges != nil && len(*rule.PortRanges) != 0 { + for _, portRange := range *rule.PortRanges { + if portRange.Start < 1 || portRange.End > 65535 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) + return + } + pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + Start: uint16(portRange.Start), + End: uint16(portRange.End), + }) + } + } + // validate policy object switch pr.Protocol { case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: - if len(pr.Ports) != 0 { + if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return } @@ -197,7 +212,7 @@ func (h *Policies) savePolicy( return } case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: - if !pr.Bidirectional && len(pr.Ports) == 0 { + if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } @@ -207,15 +222,21 @@ func (h *Policies) savePolicy( } if req.SourcePostureChecks != nil { - policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) + policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { util.WriteError(r.Context(), err, w) return } - resp := toPolicyResponse(account, &policy) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(allGroups, &policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return @@ -227,12 +248,11 @@ func (h *Policies) savePolicy( // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id vars := mux.Vars(r) policyID := vars["policyId"] @@ -241,7 +261,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -252,40 +272,46 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - vars := mux.Vars(r) - policyID := vars["policyId"] - if len(policyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) - return - } - - policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toPolicyResponse(account, policy) - if len(resp.Rules) == 0 { - util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) - return - } - - util.WriteJSONObject(r.Context(), w, resp) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) + vars := mux.Vars(r) + policyID := vars["policyId"] + if len(policyID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + return } + + policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(allGroups, policy) + if len(resp.Rules) == 0 { + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + return + } + + util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { +func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { + groupsMap := make(map[string]*nbgroup.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + cache := make(map[string]api.GroupMinimum) ap := &api.Policy{ Id: &policy.ID, @@ -306,16 +332,29 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic Protocol: api.PolicyRuleProtocol(r.Protocol), Action: api.PolicyRuleAction(r.Action), } + if len(r.Ports) != 0 { portsCopy := r.Ports rule.Ports = &portsCopy } + + if len(r.PortRanges) != 0 { + portRanges := make([]api.RulePortRange, 0, len(r.PortRanges)) + for _, portRange := range r.PortRanges { + portRanges = append(portRanges, api.RulePortRange{ + End: int(portRange.End), + Start: int(portRange.Start), + }) + } + rule.PortRanges = &portRanges + } + for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -325,13 +364,14 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic cache[gid] = minimum } } + for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { rule.Destinations = append(rule.Destinations, cachedMinimum) continue } - if group, ok := account.Groups[gid]; ok { + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, @@ -345,28 +385,3 @@ func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Polic } return ap } - -func groupMinimumsToStrings(account *server.Account, gm []string) []string { - result := make([]string, 0, len(gm)) - for _, g := range gm { - if _, ok := account.Groups[g]; !ok { - continue - } - result = append(result, g) - } - return result -} - -func sourcePostureChecksToStrings(account *server.Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) - for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } - } - - } - return result -} diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 06274fb07..228ebcbce 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,17 +38,23 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { + return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + }, + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil + }, + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + user := server.NewAdminUser(userID) return &server.Account{ - Id: claims.AccountId, + Id: accountID, Domain: "hotmail.com", Policies: []*server.Policy{ {ID: "id-existed"}, @@ -60,7 +66,7 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { Users: map[string]*server.User{ "test_user": user, }, - }, user, nil + }, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 059cb3b80..1d020e9bc 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -37,20 +37,20 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) + listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - postureChecks := []*api.PostureCheck{} - for _, postureCheck := range accountPostureChecks { + postureChecks := make([]*api.PostureCheck, 0, len(listPostureChecks)) + for _, postureCheck := range listPostureChecks { postureChecks = append(postureChecks, postureCheck.ToAPIResponse()) } @@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -73,37 +73,31 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - postureChecksIdx := -1 - for i, postureCheck := range account.PostureChecks { - if postureCheck.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find posture checks id %s", postureChecksID), w) + _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, postureChecksID) + p.savePostureChecks(w, r, accountID, userID, postureChecksID) } // CreatePostureCheck handles posture check creation request func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, account, user, "") + p.savePostureChecks(w, r, accountID, userID, "") } // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -116,7 +110,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,7 +122,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -141,7 +135,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -150,13 +144,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks( - w http.ResponseWriter, - r *http.Request, - account *server.Account, - user *server.User, - postureChecksID string, -) { +func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate @@ -181,7 +169,7 @@ func (p *PostureChecksHandler) savePostureChecks( return } - if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 974edafde..02f0f0d83 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -67,15 +66,8 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - PostureChecks: postureChecks, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, }, geolocationManager: &geolocation.Geolocation{}, diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 18c347334..ce4edee4f 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -117,15 +117,14 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { peerGroupIds = *req.PeerGroups } - // Do not allow non-Linux peers - if peer := account.GetPeer(peerId); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes"), w) - return - } + var accessControlGroupIds []string + if req.AccessControlGroups != nil { + accessControlGroupIds = *req.AccessControlGroups } - newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) + newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute) + if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +167,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -181,7 +180,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -204,14 +203,6 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { peerID = *req.Peer } - // do not allow non Linux peers - if peer := account.GetPeer(peerID); peer != nil { - if peer.Meta.GoOS != "linux" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "non-linux peers are non supported as network routes"), w) - return - } - } - newRoute := &route.Route{ ID: route.ID(routeID), NetID: route.NetID(req.NetworkId), @@ -247,7 +238,11 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) + if req.AccessControlGroups != nil { + newRoute.AccessControlGroups = *req.AccessControlGroups + } + + err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -265,7 +260,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -277,7 +272,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -289,7 +284,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -301,7 +296,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) return @@ -340,6 +335,9 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) { if len(serverRoute.PeerGroups) > 0 { route.PeerGroups = &serverRoute.PeerGroups } + if len(serverRoute.AccessControlGroups) > 0 { + route.AccessControlGroups = &serverRoute.AccessControlGroups + } return route, nil } diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 40075eb9d..83bd7004d 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -105,32 +105,44 @@ func initRoutesTestData() *RoutesHandler { } return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) }, - CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { + CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) { if peerID == notFoundPeerID { return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } if len(peerGroups) > 0 && peerGroups[0] == notFoundGroupID { return nil, status.Errorf(status.InvalidArgument, "peer groups with ID %s not found", peerGroups[0]) } + if peerID != "" { + if peerID == nonLinuxExistingPeerID { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + return &route.Route{ - ID: existingRouteID, - NetID: netID, - Peer: peerID, - PeerGroups: peerGroups, - Network: prefix, - Domains: domains, - NetworkType: networkType, - Description: description, - Masquerade: masquerade, - Enabled: enabled, - Groups: groups, - KeepRoute: keepRoute, + ID: existingRouteID, + NetID: netID, + Peer: peerID, + PeerGroups: peerGroups, + Network: prefix, + Domains: domains, + NetworkType: networkType, + Description: description, + Masquerade: masquerade, + Enabled: enabled, + Groups: groups, + KeepRoute: keepRoute, + AccessControlGroups: accessControlGroups, }, nil }, SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error { if r.Peer == notFoundPeerID { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", r.Peer) } + + if r.Peer == nonLinuxExistingPeerID { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + return nil }, DeleteRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) error { @@ -139,8 +151,9 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingAccount, testingAccount.Users["test_user"], nil + GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + //return testingAccount, testingAccount.Users["test_user"], nil + return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -256,6 +269,27 @@ func TestRoutesHandlers(t *testing.T) { Groups: []string{existingGroupID}, }, }, + { + name: "POST OK With Access Control Groups", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBuffer( + []byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))), + expectedStatus: http.StatusOK, + expectedBody: true, + expectedRoute: &api.Route{ + Id: existingRouteID, + Description: "Post", + NetworkId: "awesomeNet", + Network: toPtr("192.168.0.0/16"), + Peer: &existingPeerID, + NetworkType: route.IPv4NetworkString, + Masquerade: false, + Enabled: false, + Groups: []string{existingGroupID}, + AccessControlGroups: &[]string{existingGroupID}, + }, + }, { name: "POST Non Linux Peer", requestType: http.MethodPost, diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8ee7dfaba..8514f0b55 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, user.Id, ephemeral) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) return @@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index bfa0ec008..2d15287af 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" @@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - SetupKeys: map[string]*server.SetupKey{ - defaultKey.Key: defaultKey, - }, - Groups: map[string]*nbgroup.Group{ - "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}, - }, - }, user, nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, claims.UserId, nil }, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 2c2aed842..6e151a0da 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,22 +41,22 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] - if len(userID) == 0 { + targetUserID := vars["userId"] + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - existingUser, ok := account.Users[userID] - if !ok { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find user with ID %s", userID), w) + existingUser, err := h.accountManager.GetUserByID(r.Context(), targetUserID) + if err != nil { + util.WriteError(r.Context(), err, w) return } @@ -78,8 +78,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ - Id: userID, + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, Blocked: req.IsBlocked, @@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index a78ac3a4e..f3d989da1 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return usersTestAccount, usersTestAccount.Users[claims.UserId], nil + GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return usersTestAccount.Id, claims.UserId, nil + }, + GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + return usersTestAccount.Users[id], nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) diff --git a/management/server/idp/zitadel.go b/management/server/idp/zitadel.go index 729b49733..9d7626844 100644 --- a/management/server/idp/zitadel.go +++ b/management/server/idp/zitadel.go @@ -2,10 +2,12 @@ package idp import ( "context" + "errors" "fmt" "io" "net/http" "net/url" + "slices" "strings" "sync" "time" @@ -97,6 +99,42 @@ type zitadelUserResponse struct { PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"` } +// readZitadelError parses errors returned by the zitadel APIs from a response. +func readZitadelError(body io.ReadCloser) error { + bodyBytes, err := io.ReadAll(body) + if err != nil { + return fmt.Errorf("failed to read response body: %w", err) + } + + helper := JsonParser{} + var target map[string]interface{} + err = helper.Unmarshal(bodyBytes, &target) + if err != nil { + return fmt.Errorf("error unparsable body: %s", string(bodyBytes)) + } + + // ensure keys are ordered for consistent logging behaviour. + errorKeys := make([]string, 0, len(target)) + for k := range target { + errorKeys = append(errorKeys, k) + } + slices.Sort(errorKeys) + + var errsOut []string + for _, k := range errorKeys { + if _, isEmbedded := target[k].(map[string]interface{}); isEmbedded { + continue + } + errsOut = append(errsOut, fmt.Sprintf("%s: %v", k, target[k])) + } + + if len(errsOut) == 0 { + return errors.New("unknown error") + } + + return errors.New(strings.Join(errsOut, " ")) +} + // NewZitadelManager creates a new instance of the ZitadelManager. func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) { httpTransport := http.DefaultTransport.(*http.Transport).Clone() @@ -176,7 +214,8 @@ func (zc *ZitadelCredentials) requestJWTToken(ctx context.Context) (*http.Respon } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unable to get zitadel token, statusCode %d", resp.StatusCode) + zErr := readZitadelError(resp.Body) + return nil, fmt.Errorf("unable to get zitadel token, statusCode %d, zitadel: %w", resp.StatusCode, zErr) } return resp, nil @@ -489,7 +528,9 @@ func (zm *ZitadelManager) post(ctx context.Context, resource string, body string zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to post %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to post %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) @@ -561,7 +602,9 @@ func (zm *ZitadelManager) get(ctx context.Context, resource string, q url.Values zm.appMetrics.IDPMetrics().CountRequestStatusError() } - return nil, fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode) + zErr := readZitadelError(resp.Body) + + return nil, fmt.Errorf("unable to get %s, statusCode %d, zitadel: %w", reqURL, resp.StatusCode, zErr) } return io.ReadAll(resp.Body) diff --git a/management/server/idp/zitadel_test.go b/management/server/idp/zitadel_test.go index 6bc612e78..722f94fe0 100644 --- a/management/server/idp/zitadel_test.go +++ b/management/server/idp/zitadel_test.go @@ -66,7 +66,6 @@ func TestNewZitadelManager(t *testing.T) { } func TestZitadelRequestJWTToken(t *testing.T) { - type requestJWTTokenTest struct { name string inputCode int @@ -88,15 +87,14 @@ func TestZitadelRequestJWTToken(t *testing.T) { requestJWTTokenTestCase2 := requestJWTTokenTest{ name: "Request Bad Status Code", inputCode: 400, - inputRespBody: "{}", + inputRespBody: "{\"error\": \"invalid_scope\", \"error_description\":\"openid missing\"}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: error: invalid_scope error_description: openid missing"), expectedToken: "", } for _, testCase := range []requestJWTTokenTest{requestJWTTokenTesttCase1, requestJWTTokenTestCase2} { t.Run(testCase.name, func(t *testing.T) { - jwtReqClient := mockHTTPClient{ resBody: testCase.inputRespBody, code: testCase.inputCode, @@ -156,7 +154,7 @@ func TestZitadelParseRequestJWTResponse(t *testing.T) { } parseRequestJWTResponseTestCase2 := parseRequestJWTResponseTest{ name: "Parse Bad json JWT Body", - inputRespBody: "", + inputRespBody: "{}", helper: JsonParser{}, expectedToken: "", expectedExpiresIn: 0, @@ -254,7 +252,7 @@ func TestZitadelAuthenticate(t *testing.T) { inputCode: 400, inputResBody: "{}", helper: JsonParser{}, - expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400"), + expectedFuncExitErrDiff: fmt.Errorf("unable to get zitadel token, statusCode 400, zitadel: unknown error"), expectedCode: 200, expectedToken: "", } diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index 39676982e..d5c1e7c9e 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -1,14 +1,12 @@ package jwtclaims import ( - "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rsa" - "crypto/x509" "encoding/base64" - "encoding/binary" "encoding/json" - "encoding/pem" "errors" "fmt" "math/big" @@ -41,11 +39,6 @@ type Options struct { // When set, all requests with the OPTIONS method will use authentication // Default: false EnableAuthOnOptions bool - // When set, the middelware verifies that tokens are signed with the specific signing algorithm - // If the signing method is not constant the ValidationKeyGetter callback can be used to implement additional checks - // Important to avoid security issues described here: https://auth0.com/blog/critical-vulnerabilities-in-json-web-token-libraries/ - // Default: nil - SigningMethod jwt.SigningMethod } // Jwks is a collection of JSONWebKey obtained from Config.HttpServerConfig.AuthKeysLocation @@ -54,6 +47,18 @@ type Jwks struct { expiresInTime time.Time } +// The supported elliptic curves types +const ( + // p256 represents a cryptographic elliptical curve type. + p256 = "P-256" + + // p384 represents a cryptographic elliptical curve type. + p384 = "P-384" + + // p521 represents a cryptographic elliptical curve type. + p521 = "P-521" +) + // JSONWebKey is a representation of a Jason Web Key type JSONWebKey struct { Kty string `json:"kty"` @@ -61,6 +66,9 @@ type JSONWebKey struct { Use string `json:"use"` N string `json:"n"` E string `json:"e"` + Crv string `json:"crv"` + X string `json:"x"` + Y string `json:"y"` X5c []string `json:"x5c"` } @@ -115,15 +123,14 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, } } - cert, err := getPemCert(ctx, token, keys) + publicKey, err := getPublicKey(ctx, token, keys) if err != nil { + log.WithContext(ctx).Errorf("getPublicKey error: %s", err) return nil, err } - result, _ := jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) - return result, nil + return publicKey, nil }, - SigningMethod: jwt.SigningMethodRS256, EnableAuthOnOptions: false, } @@ -159,15 +166,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt // Check if there was an error in parsing... if err != nil { log.WithContext(ctx).Errorf("error parsing token: %v", err) - return nil, fmt.Errorf("Error parsing token: %w", err) - } - - if m.options.SigningMethod != nil && m.options.SigningMethod.Alg() != parsedToken.Header["alg"] { - errorMsg := fmt.Sprintf("Expected %s signing method but token specified %s", - m.options.SigningMethod.Alg(), - parsedToken.Header["alg"]) - log.WithContext(ctx).Debugf("error validating token algorithm: %s", errorMsg) - return nil, fmt.Errorf("error validating token algorithm: %s", errorMsg) + return nil, fmt.Errorf("error parsing token: %w", err) } // Check if the parsed token is valid... @@ -205,9 +204,8 @@ func getPemKeys(ctx context.Context, keysLocation string) (*Jwks, error) { return jwks, err } -func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, error) { +func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{}, error) { // todo as we load the jkws when the server is starting, we should build a JKS map with the pem cert at the boot time - cert := "" for k := range jwks.Keys { if token.Header["kid"] != jwks.Keys[k].Kid { @@ -215,73 +213,79 @@ func getPemCert(ctx context.Context, token *jwt.Token, jwks *Jwks) (string, erro } if len(jwks.Keys[k].X5c) != 0 { - cert = "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" - return cert, nil + cert := "-----BEGIN CERTIFICATE-----\n" + jwks.Keys[k].X5c[0] + "\n-----END CERTIFICATE-----" + return jwt.ParseRSAPublicKeyFromPEM([]byte(cert)) } - log.WithContext(ctx).Debugf("generating validation pem from JWK") - return generatePemFromJWK(jwks.Keys[k]) + + if jwks.Keys[k].Kty == "RSA" { + log.WithContext(ctx).Debugf("generating PublicKey from RSA JWK") + return getPublicKeyFromRSA(jwks.Keys[k]) + } + if jwks.Keys[k].Kty == "EC" { + log.WithContext(ctx).Debugf("generating PublicKey from ECDSA JWK") + return getPublicKeyFromECDSA(jwks.Keys[k]) + } + + log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) } - return cert, errors.New("unable to find appropriate key") + return nil, errors.New("unable to find appropriate key") } -func generatePemFromJWK(jwk JSONWebKey) (string, error) { - decodedModulus, err := base64.RawURLEncoding.DecodeString(jwk.N) - if err != nil { - return "", fmt.Errorf("unable to decode JWK modulus, error: %s", err) +func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { + + if jwk.X == "" || jwk.Y == "" || jwk.Crv == "" { + return nil, fmt.Errorf("ecdsa key incomplete") } - intModules := big.NewInt(0) - intModules.SetBytes(decodedModulus) - - exponent, err := convertExponentStringToInt(jwk.E) - if err != nil { - return "", fmt.Errorf("unable to decode JWK exponent, error: %s", err) + var xCoordinate []byte + if xCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.X); err != nil { + return nil, err } - publicKey := &rsa.PublicKey{ - N: intModules, - E: exponent, + var yCoordinate []byte + if yCoordinate, err = base64.RawURLEncoding.DecodeString(jwk.Y); err != nil { + return nil, err } - derKey, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - return "", fmt.Errorf("unable to convert public key to DER, error: %s", err) + publicKey = &ecdsa.PublicKey{} + + var curve elliptic.Curve + switch jwk.Crv { + case p256: + curve = elliptic.P256() + case p384: + curve = elliptic.P384() + case p521: + curve = elliptic.P521() } - block := &pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: derKey, - } + publicKey.Curve = curve + publicKey.X = big.NewInt(0).SetBytes(xCoordinate) + publicKey.Y = big.NewInt(0).SetBytes(yCoordinate) - var out bytes.Buffer - err = pem.Encode(&out, block) - if err != nil { - return "", fmt.Errorf("unable to encode Pem block , error: %s", err) - } - - return out.String(), nil + return publicKey, nil } -func convertExponentStringToInt(stringExponent string) (int, error) { - decodedString, err := base64.StdEncoding.DecodeString(stringExponent) +func getPublicKeyFromRSA(jwk JSONWebKey) (*rsa.PublicKey, error) { + + decodedE, err := base64.RawURLEncoding.DecodeString(jwk.E) if err != nil { - return 0, err + return nil, err } - exponentBytes := decodedString - if len(decodedString) < 8 { - exponentBytes = make([]byte, 8-len(decodedString), 8) - exponentBytes = append(exponentBytes, decodedString...) + decodedN, err := base64.RawURLEncoding.DecodeString(jwk.N) + if err != nil { + return nil, err } - bytesReader := bytes.NewReader(exponentBytes) - var exponent uint64 - err = binary.Read(bytesReader, binary.BigEndian, &exponent) - if err != nil { - return 0, err - } + var n, e big.Int + e.SetBytes(decodedE) + n.SetBytes(decodedN) - return int(exponent), nil + return &rsa.PublicKey{ + E: int(e.Int64()), + N: &n, + }, nil } // getMaxAgeFromCacheHeader extracts max-age directive from the Cache-Control header @@ -306,3 +310,4 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } + diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index fe1e36d47..ff09129bd 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -3,13 +3,17 @@ package server import ( "context" "fmt" + "io" "net" "os" "path/filepath" "runtime" + "sync" + "sync/atomic" "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -24,6 +28,12 @@ import ( "github.com/netbirdio/netbird/util" ) +type TestingT interface { + require.TestingT + Helper() + Cleanup(func()) +} + var ( kaep = keepalive.EnforcementPolicy{ MinTime: 15 * time.Second, @@ -86,7 +96,7 @@ func Test_SyncProtocol(t *testing.T) { defer func() { os.Remove(filepath.Join(dir, "store.json")) //nolint }() - mgmtServer, _, mgmtAddr, err := startManagement(t, &Config{ + mgmtServer, _, mgmtAddr, err := startManagementForTest(t, &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -402,7 +412,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) { } } -func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { +func startManagementForTest(t TestingT, config *Config) (*grpc.Server, *DefaultAccountManager, string, error) { t.Helper() lis, err := net.Listen("tcp", "localhost:0") if err != nil { @@ -429,10 +439,11 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun if err != nil { return nil, nil, "", err } - turnManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) + + secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { return nil, nil, "", err } @@ -485,7 +496,7 @@ func testSyncStatusRace(t *testing.T) { os.Remove(filepath.Join(dir, "store.json")) //nolint }() - mgmtServer, am, mgmtAddr, err := startManagement(t, &Config{ + mgmtServer, am, mgmtAddr, err := startManagementForTest(t, &Config{ Stuns: []*Host{{ Proto: "udp", URI: "stun:stun.wiretrustee.com:3468", @@ -545,7 +556,6 @@ func testSyncStatusRace(t *testing.T) { ctx2, cancelFunc2 := context.WithCancel(context.Background()) - //client. sync2, err := client.Sync(ctx2, &mgmtProto.EncryptedMessage{ WgPubKey: concurrentPeerKey2.PublicKey().String(), Body: message2, @@ -574,7 +584,7 @@ func testSyncStatusRace(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) - //client. + // client. sync, err := client.Sync(ctx, &mgmtProto.EncryptedMessage{ WgPubKey: peerWithInvalidStatus.PublicKey().String(), Body: message, @@ -617,7 +627,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -626,3 +636,221 @@ func testSyncStatusRace(t *testing.T) { t.Fatal("Peer should be connected") } } + +func Test_LoginPerformance(t *testing.T) { + if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { + t.Skip("Skipping test on CI or Windows") + } + + t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") + + benchCases := []struct { + name string + peers int + accounts int + }{ + // {"XXS", 5, 1}, + // {"XS", 10, 1}, + // {"S", 100, 1}, + // {"M", 250, 1}, + // {"L", 500, 1}, + // {"XL", 750, 1}, + {"XXL", 5000, 1}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + for _, bc := range benchCases { + t.Run(bc.name, func(t *testing.T) { + t.Helper() + dir := t.TempDir() + err := util.CopyFileContents("testdata/store_with_expired_peers.json", filepath.Join(dir, "store.json")) + if err != nil { + t.Fatal(err) + } + defer func() { + os.Remove(filepath.Join(dir, "store.json")) //nolint + }() + + mgmtServer, am, _, err := startManagementForTest(t, &Config{ + Stuns: []*Host{{ + Proto: "udp", + URI: "stun:stun.wiretrustee.com:3468", + }}, + TURNConfig: &TURNConfig{ + TimeBasedCredentials: false, + CredentialsTTL: util.Duration{}, + Secret: "whatever", + Turns: []*Host{{ + Proto: "udp", + URI: "turn:stun.wiretrustee.com:3468", + }}, + }, + Signal: &Host{ + Proto: "http", + URI: "signal.wiretrustee.com:10000", + }, + Datadir: dir, + HttpConfig: nil, + }) + if err != nil { + t.Fatal(err) + return + } + defer mgmtServer.GracefulStop() + + t.Logf("management setup complete, start registering peers") + + var counter int32 + var counterStart int32 + var wgAccount sync.WaitGroup + var mu sync.Mutex + messageCalls := []func() error{} + for j := 0; j < bc.accounts; j++ { + wgAccount.Add(1) + var wgPeer sync.WaitGroup + go func(j int, counter *int32, counterStart *int32) { + defer wgAccount.Done() + + account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j)) + if err != nil { + t.Logf("account creation failed: %v", err) + return + } + + setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) + if err != nil { + t.Logf("error creating setup key: %v", err) + return + } + + startTime := time.Now() + for i := 0; i < bc.peers; i++ { + wgPeer.Add(1) + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Logf("failed to generate key: %v", err) + return + } + + meta := &mgmtProto.PeerSystemMeta{ + Hostname: key.PublicKey().String(), + GoOS: runtime.GOOS, + OS: runtime.GOOS, + Core: "core", + Platform: "platform", + Kernel: "kernel", + WiretrusteeVersion: "", + } + + peerLogin := PeerLogin{ + WireGuardPubKey: key.String(), + SSHKey: "random", + Meta: extractPeerMeta(context.Background(), meta), + SetupKey: setupKey.Key, + ConnectionIP: net.IP{1, 1, 1, 1}, + } + + login := func() error { + _, _, _, err = am.LoginPeer(context.Background(), peerLogin) + if err != nil { + t.Logf("failed to login peer: %v", err) + return err + } + atomic.AddInt32(counter, 1) + if *counter%100 == 0 { + t.Logf("finished %d login calls", *counter) + } + return nil + } + + mu.Lock() + messageCalls = append(messageCalls, login) + mu.Unlock() + + go func(peerLogin PeerLogin, counterStart *int32) { + defer wgPeer.Done() + _, _, _, err = am.LoginPeer(context.Background(), peerLogin) + if err != nil { + t.Logf("failed to login peer: %v", err) + return + } + + atomic.AddInt32(counterStart, 1) + if *counterStart%100 == 0 { + t.Logf("registered %d peers", *counterStart) + } + }(peerLogin, counterStart) + + } + wgPeer.Wait() + + t.Logf("Time for registration: %s", time.Since(startTime)) + }(j, &counter, &counterStart) + } + + wgAccount.Wait() + + t.Logf("prepared %d login calls", len(messageCalls)) + testLoginPerformance(t, messageCalls) + + }) + } +} + +func testLoginPerformance(t *testing.T, loginCalls []func() error) { + t.Helper() + wgSetup := sync.WaitGroup{} + startChan := make(chan struct{}) + + wgDone := sync.WaitGroup{} + durations := []time.Duration{} + l := sync.Mutex{} + + for i, function := range loginCalls { + wgSetup.Add(1) + wgDone.Add(1) + go func(function func() error, i int) { + defer wgDone.Done() + wgSetup.Done() + + <-startChan + start := time.Now() + + err := function() + if err != nil { + t.Logf("Error: %v", err) + return + } + + duration := time.Since(start) + l.Lock() + durations = append(durations, duration) + l.Unlock() + }(function, i) + } + + wgSetup.Wait() + t.Logf("starting login calls") + close(startChan) + wgDone.Wait() + var tMin, tMax, tSum time.Duration + for i, d := range durations { + if i == 0 { + tMin = d + tMax = d + tSum = d + continue + } + if d < tMin { + tMin = d + } + if d > tMax { + tMax = d + } + tSum += d + } + tAvg := tSum / time.Duration(len(durations)) + t.Logf("Min: %v, Max: %v, Avg: %v", tMin, tMax, tAvg) +} diff --git a/management/server/management_test.go b/management/server/management_test.go index 62e7f5a05..3956d96b1 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -552,8 +552,9 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { if err != nil { log.Fatalf("failed creating a manager: %v", err) } - turnManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, turnManager, nil, nil) + + secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index f4452eaee..8dc8f6a4f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -23,10 +23,11 @@ import ( type MockAccountManager struct { GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) + GetAccountIDByUserOrAccountIdFunc func(ctx context.Context, userId, accountId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) @@ -48,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -56,7 +57,7 @@ type MockAccountManager struct { MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups,accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error @@ -78,7 +79,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string @@ -104,6 +105,9 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) } func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { @@ -189,16 +193,14 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountByUserOrAccountID mock implementation of GetAccountByUserOrAccountID from server.AccountManager interface -func (am *MockAccountManager) GetAccountByUserOrAccountID( - ctx context.Context, userId, accountId, domain string, -) (*server.Account, error) { - if am.GetAccountByUserOrAccountIdFunc != nil { - return am.GetAccountByUserOrAccountIdFunc(ctx, userId, accountId, domain) +// GetAccountIDByUserOrAccountID mock implementation of GetAccountIDByUserOrAccountID from server.AccountManager interface +func (am *MockAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userId, accountId, domain string) (string, error) { + if am.GetAccountIDByUserOrAccountIdFunc != nil { + return am.GetAccountIDByUserOrAccountIdFunc(ctx, userId, accountId, domain) } - return nil, status.Errorf( + return "", status.Errorf( codes.Unimplemented, - "method GetAccountByUserOrAccountID is not implemented", + "method GetAccountIDByUserOrAccountID is not implemented", ) } @@ -364,7 +366,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, if am.DeleteRuleFunc != nil { return am.DeleteRuleFunc(ctx, accountID, ruleID, userID) } - return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented") + return status.Errorf(codes.Unimplemented, "method DeletePeerRule is not implemented") } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface @@ -376,9 +378,9 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy) + return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) } return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } @@ -431,9 +433,9 @@ func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID } // CreateRoute mock implementation of CreateRoute from server.AccountManager interface -func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *MockAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupID []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { if am.CreateRouteFunc != nil { - return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups, enabled, userID, keepRoute) + return am.CreateRouteFunc(ctx, accountID, prefix, networkType, domains, peerID, peerGroupIDs, description, netID, masquerade, metric, groups,accessControlGroupID, enabled, userID, keepRoute) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } @@ -592,14 +594,12 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID return nil, status.Errorf(codes.Unimplemented, "method CreateUser is not implemented") } -// GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, - error, -) { - if am.GetAccountFromTokenFunc != nil { - return am.GetAccountFromTokenFunc(ctx, claims) +// GetAccountIDFromToken mocks GetAccountIDFromToken of the AccountManager interface +func (am *MockAccountManager) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + if am.GetAccountIDFromTokenFunc != nil { + return am.GetAccountIDFromTokenFunc(ctx, claims) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountIDFromToken is not implemented") } func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { @@ -793,3 +793,33 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } return "", status.Errorf(codes.Unimplemented, "method GetAccountIDForPeerKey is not implemented") } + +// GetAccountByID mocks GetAccountByID of the AccountManager interface +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { + if am.GetAccountByIDFunc != nil { + return am.GetAccountByIDFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountByID is not implemented") +} + +// GetUserByID mocks GetUserByID of the AccountManager interface +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { + if am.GetUserByIDFunc != nil { + return am.GetUserByIDFunc(ctx, id) + } + return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") +} + +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + if am.GetAccountSettingsFunc != nil { + return am.GetAccountSettingsFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") +} + +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { + if am.GetAccountFunc != nil { + return am.GetAccountFunc(ctx, accountID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 2cd934065..751ca12bc 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -19,30 +19,16 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view nameserver groups") - } - - nsGroup, found := account.NameServerGroups[nsGroupID] - if found { - return nsGroup.Copy(), nil - } - - return nil, status.Errorf(status.NotFound, "nameserver group with ID %s not found", nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -159,30 +145,16 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - nsGroups := make([]*nbdns.NameServerGroup, 0, len(account.NameServerGroups)) - for _, item := range account.NameServerGroups { - nsGroups = append(nsGroups, item.Copy()) - } - - return nsGroups, nil + return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { diff --git a/management/server/network.go b/management/server/network.go index 91d844c3e..8fb6a8b3c 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -26,12 +26,13 @@ const ( ) type NetworkMap struct { - Peers []*nbpeer.Peer - Network *Network - Routes []*route.Route - DNSConfig nbdns.Config - OfflinePeers []*nbpeer.Peer - FirewallRules []*FirewallRule + Peers []*nbpeer.Peer + Network *Network + Routes []*route.Route + DNSConfig nbdns.Config + OfflinePeers []*nbpeer.Peer + FirewallRules []*FirewallRule + RoutesFirewallRules []*RouteFirewallRule } type Network struct { diff --git a/management/server/peer.go b/management/server/peer.go index 7e8dbe1a6..97e11c08a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "slices" "strings" "sync" "time" @@ -12,6 +11,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/proto" @@ -185,9 +185,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) } - peerLabelUpdated := peer.Name != update.Name - - if peerLabelUpdated { + if peer.Name != update.Name { peer.Name = update.Name existingLabels := account.getPeerDNSLabels() @@ -222,16 +220,13 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } account.UpdatePeer(peer) + err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) - - if peerLabelUpdated || (expired && peer.LoginExpirationEnabled) { - am.updateAccountPeers(ctx, account) - } + am.updateAccountPeers(ctx, account) return peer, nil } @@ -295,8 +290,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) - err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { return err @@ -307,9 +300,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - if updateAccountPeers { - am.updateAccountPeers(ctx, account) - } + am.updateAccountPeers(ctx, account) return nil } @@ -383,179 +374,207 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } }() - var account *Account - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { - if am.idpManager != nil { - userdata, err := am.lookupUserInCache(ctx, userID, account) - if err == nil && userdata != nil { - peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) - } - } - } - // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = account.FindPeerByPubKey(peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ Timestamp: time.Now().UTC(), - AccountID: account.Id, + AccountID: accountID, } - var ephemeral bool - setupKeyName := "" - if !addedByUser { - // validate the setup key if adding with a key - sk, err := account.FindSetupKey(upperKey) - if err != nil { - return nil, nil, nil, err - } + var newPeer *nbpeer.Peer - if !sk.IsValid() { - return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") - } - - account.SetupKeys[sk.Key] = sk.IncrementUsage() - opEvent.InitiatorID = sk.Id - opEvent.Activity = activity.PeerAddedWithSetupKey - ephemeral = sk.Ephemeral - setupKeyName = sk.Name - } else { - opEvent.InitiatorID = userID - opEvent.Activity = activity.PeerAddedByUser - } - - takenIps := account.getTakenIPs() - existingLabels := account.getPeerDNSLabels() - - newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels) - if err != nil { - return nil, nil, nil, err - } - - peer.DNSLabel = newLabel - network := account.Network - nextIp, err := AllocatePeerIP(network.Net, takenIps) - if err != nil { - return nil, nil, nil, err - } - - registrationTime := time.Now().UTC() - - newPeer := &nbpeer.Peer{ - ID: xid.New().String(), - Key: peer.Key, - SetupKey: upperKey, - IP: nextIp, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: newLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, - } - - if am.geo != nil && newPeer.Location.ConnectionIP != nil { - location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + var groupsToAdd []string + var setupKeyID string + var setupKeyName string + var ephemeral bool + if addedByUser { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + if err != nil { + return fmt.Errorf("failed to get user groups: %w", err) + } + groupsToAdd = user.AutoGroups + opEvent.InitiatorID = userID + opEvent.Activity = activity.PeerAddedByUser } else { - newPeer.Location.CountryCode = location.Country.ISOCode - newPeer.Location.CityName = location.City.Names.En - newPeer.Location.GeoNameID = location.City.GeonameID - } - } + // Validate the setup key + sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } - // add peer to 'All' group - group, err := account.GetGroupAll() - if err != nil { - return nil, nil, nil, err - } - group.Peers = append(group.Peers, newPeer.ID) + if !sk.IsValid() { + return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + } - var groupsToAdd []string - if addedByUser { - groupsToAdd, err = account.getUserGroups(userID) - if err != nil { - return nil, nil, nil, err + opEvent.InitiatorID = sk.Id + opEvent.Activity = activity.PeerAddedWithSetupKey + groupsToAdd = sk.AutoGroups + ephemeral = sk.Ephemeral + setupKeyID = sk.Id + setupKeyName = sk.Name } - } else { - groupsToAdd, err = account.getSetupKeyGroups(upperKey) - if err != nil { - return nil, nil, nil, err - } - } - if len(groupsToAdd) > 0 { - for _, s := range groupsToAdd { - if g, ok := account.Groups[s]; ok && g.Name != "All" { - g.Peers = append(g.Peers, newPeer.ID) + if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { + if am.idpManager != nil { + userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) + if err == nil && userdata != nil { + peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) + } } } - } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) - - if addedByUser { - user, err := account.FindUser(userID) + freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname) if err != nil { - return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") + return fmt.Errorf("failed to get free DNS label: %w", err) } - user.updateLastLogin(newPeer.LastLogin) - } - account.Peers[newPeer.ID] = newPeer - account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) + freeIP, err := am.getFreeIP(ctx, transaction, accountID) + if err != nil { + return fmt.Errorf("failed to get free IP: %w", err) + } + + registrationTime := time.Now().UTC() + newPeer = &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + } + opEvent.TargetID = newPeer.ID + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) + if !addedByUser { + opEvent.Meta["setup_key_name"] = setupKeyName + } + + if am.geo != nil && newPeer.Location.ConnectionIP != nil { + location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + } else { + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + } + } + + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + + err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + + if len(groupsToAdd) > 0 { + for _, g := range groupsToAdd { + err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) + if err != nil { + return err + } + } + } + + err = transaction.AddPeerToAccount(ctx, newPeer) + if err != nil { + return fmt.Errorf("failed to add peer to account: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if addedByUser { + err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin) + if err != nil { + return fmt.Errorf("failed to update user last login: %w", err) + } + } else { + err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) + if err != nil { + return fmt.Errorf("failed to increment setup key usage: %w", err) + } + } + + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) + return nil + }) + if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } - // Account is saved, we can release the lock - unlock() - unlock = nil - - opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) - if !addedByUser { - opEvent.Meta["setup_key_name"] = setupKeyName + if newPeer == nil { + return nil, nil, nil, fmt.Errorf("new peer is nil") } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - if areGroupChangesAffectPeers(account, groupsToAdd) { - am.updateAccountPeers(ctx, account) + unlock() + unlock = nil + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, fmt.Errorf("error getting account: %w", err) } + am.updateAccountPeers(ctx, account) + approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, peer) + postureChecks := am.getPeerPostureChecks(account, newPeer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil } +func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { + takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get taken IPs: %w", err) + } + + network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return nil, fmt.Errorf("failed getting network: %w", err) + } + + nextIp, err := AllocatePeerIP(network.Net, takenIps) + if err != nil { + return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) + } + + return nextIp, nil +} + // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) @@ -563,16 +582,23 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, status.NewPeerNotRegisteredError() } - err = checkIfPeerOwnerIsBlocked(peer, account) - if err != nil { - return nil, nil, nil, err + if peer.UserID != "" { + user, err := account.FindUser(peer.UserID) + if err != nil { + return nil, nil, nil, err + } + + err = checkIfPeerOwnerIsBlocked(peer, user) + if err != nil { + return nil, nil, nil, err + } } if peerLoginExpired(ctx, peer, account.Settings) { return nil, nil, nil, status.NewPeerLoginExpiredError() } - peer, updated := updatePeerMeta(peer, sync.Meta, account) + updated := peer.UpdateMetaIfNew(sync.Meta) if updated { err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { @@ -589,22 +615,16 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, err } + var postureChecks []*posture.Checks + if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, nil, nil + return peer, emptyMap, postureChecks, nil } - peer, peerMetaUpdated := updatePeerMeta(peer, sync.Meta, account) - if peerMetaUpdated { - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, nil, nil, err - } - } - - if isStatusChanged || (peerMetaUpdated && sync.UpdateAccountPeers) { + if isStatusChanged { am.updateAccountPeers(ctx, account) } @@ -612,7 +632,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, peer) + postureChecks = am.getPeerPostureChecks(account, peer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -644,31 +664,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // it means that the client has already checked if it needs login and had been through the SSO flow // so, we can skip this check and directly proceed with the login if login.UserID == "" { + log.Info("Peer needs login") err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) if err != nil { return nil, nil, nil, err } } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID) + defer unlockAccount() + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey) defer func() { - if unlock != nil { - unlock() + if unlockPeer != nil { + unlockPeer() } }() - // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err := am.Store.GetAccount(ctx, accountID) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - peer, err := account.FindPeerByPubKey(login.WireGuardPubKey) - if err != nil { - return nil, nil, nil, status.NewPeerNotRegisteredError() - } - - err = checkIfPeerOwnerIsBlocked(peer, account) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -676,21 +693,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // this flag prevents unnecessary calls to the persistent store. shouldStorePeer := false updateRemotePeers := false - if peerLoginExpired(ctx, peer, account.Settings) { - err = am.handleExpiredPeer(ctx, login, account, peer) + + if login.UserID != "" { + changed, err := am.handleUserPeer(ctx, peer, settings) if err != nil { return nil, nil, nil, err } - updateRemotePeers = true - shouldStorePeer = true + if changed { + shouldStorePeer = true + updateRemotePeers = true + } } - isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + groups, err := am.Store.GetAccountGroups(ctx, accountID) if err != nil { return nil, nil, nil, err } - peer, updated := updatePeerMeta(peer, login.Meta, account) + var grps []string + for _, group := range groups { + for _, id := range group.Peers { + if id == peer.ID { + grps = append(grps, group.ID) + break + } + } + } + + isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra) + if err != nil { + return nil, nil, nil, err + } + + updated := peer.UpdateMetaIfNew(login.Meta) if updated { shouldStorePeer = true } @@ -707,8 +742,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - unlock() - unlock = nil + unlockPeer() + unlockPeer = nil + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } if updateRemotePeers || isStatusChanged { am.updateAccountPeers(ctx, account) @@ -723,7 +763,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -734,7 +774,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, accountID) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -766,36 +806,30 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { - err := checkAuth(ctx, login.UserID, peer) +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { + err := checkAuth(ctx, user.Id, peer) if err != nil { return err } // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. - updatePeerLastLogin(peer, account) - - // sync user last login with peer last login - user, err := account.FindUser(login.UserID) - if err != nil { - return status.Errorf(status.Internal, "couldn't find user") - } - - err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin) + peer = peer.UpdateLastLogin() + err = am.Store.SavePeer(ctx, peer.AccountID, peer) if err != nil { return err } - am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) + if err != nil { + return err + } + + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) return nil } -func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { +func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error { if peer.AddedWithSSOLogin() { - user, err := account.FindUser(peer.UserID) - if err != nil { - return status.Errorf(status.PermissionDenied, "user doesn't exist") - } if user.IsBlocked() { return status.Errorf(status.PermissionDenied, "user is blocked") } @@ -825,9 +859,49 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings return false } -func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) { - peer.UpdateLastLogin() +// UpdatePeerSSHKey updates peer's public SSH key +func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { + if sshKey == "" { + log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) + return nil + } + + account, err := am.Store.GetAccountByPeerID(ctx, peerID) + if err != nil { + return err + } + + unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) + defer unlock() + + // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) + account, err = am.Store.GetAccount(ctx, account.Id) + if err != nil { + return err + } + + peer := account.GetPeer(peerID) + if peer == nil { + return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) + } + + if peer.SSHKey == sshKey { + log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) + return nil + } + + peer.SSHKey = sshKey account.UpdatePeer(peer) + + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return err + } + + // trigger network map update + am.updateAccountPeers(ctx, account) + + return nil } // GetPeer for a given accountID, peerID and userID error if not found. @@ -883,14 +957,6 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) } -func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) { - if peer.UpdateMetaIfNew(meta) { - account.UpdatePeer(peer) - return peer, true - } - return peer, false -} - // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { @@ -929,7 +995,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks := am.getPeerPostureChecks(account, p) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks}) }(peer) } @@ -937,14 +1003,10 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } -// IsPeerInActiveGroup checks if the given peer is part of a group that is used -// in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { - peerGroupIDs := make([]string, 0) - for _, group := range account.Groups { - if slices.Contains(group.Peers, peerID) { - peerGroupIDs = append(peerGroupIDs, group.ID) - } +func ConvertSliceToMap(existingLabels []string) map[string]struct{} { + labelMap := make(map[string]struct{}, len(existingLabels)) + for _, label := range existingLabels { + labelMap[label] = struct{}{} } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return labelMap } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 102923e37..387adb91d 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -7,6 +7,7 @@ import ( "net" "net/netip" "os" + "runtime" "testing" "time" @@ -19,9 +20,11 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/telemetry" nbroute "github.com/netbirdio/netbird/route" ) @@ -248,7 +251,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { Action: PolicyTrafficActionAccept, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -296,7 +299,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy) + err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -643,7 +646,6 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { }) } - } func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) { @@ -849,9 +851,9 @@ func TestToSyncResponse(t *testing.T) { DNSLabel: "peer1", SSHKey: "peer1-ssh-key", } - turnCredentials := &TURNCredentials{ - Username: "turn-user", - Password: "turn-pass", + turnRelayToken := &Token{ + Payload: "turn-user", + Signature: "turn-pass", } networkMap := &NetworkMap{ Network: &Network{Net: *ipnet, Serial: 1000}, @@ -917,7 +919,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} - response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache) assert.NotNil(t, response) // assert peer config @@ -988,164 +990,192 @@ func TestToSyncResponse(t *testing.T) { // assert network map Firewall assert.Equal(t, 1, len(response.NetworkMap.FirewallRules)) assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP) - assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction) - assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action) - assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol) + assert.Equal(t, proto.RuleDirection_IN, response.NetworkMap.FirewallRules[0].Direction) + assert.Equal(t, proto.RuleAction_ACCEPT, response.NetworkMap.FirewallRules[0].Action) + assert.Equal(t, proto.RuleProtocol_TCP, response.NetworkMap.FirewallRules[0].Protocol) assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port) // assert posture checks assert.Equal(t, 1, len(response.Checks)) assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) } -func TestPeerAccountPeerUpdate(t *testing.T) { - manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) +func Test_RegisterPeerByUser(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } - err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" + + _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group-id", - Name: "GroupA", - Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - }) - require.NoError(t, err) - - // create a user with auto groups - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ - Id: "regularUser1", - AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, - AutoGroups: []string{"group-id"}, - }, true) - require.NoError(t, err) - - var peer4 *nbpeer.Peer - - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - }) - - // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update - t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - _, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - - // Adding peer with an unused group in active dns, route, acl should not update account peers and not send peer update - t.Run("adding peer with unused group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - key, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) - - expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ - Key: expectedPeerKey, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - - // Deleting peer with an unused group in active dns, route, acl should not update account peers and not send peer update - t.Run("deleting peer with unused group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldNotReceiveUpdate") - } - }) - - // use the group-id in policy - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"group-id"}, - Destinations: []string{"group-id"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", }, - }) + Name: "newPeerName", + DNSLabel: "newPeer.test", + UserID: existingUserID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + LastLogin: time.Now(), + } + + addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) require.NoError(t, err) - // Adding peer with a used group in active dns, route or policy should update account peers and send peer update - t.Run("adding peer with used group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() + peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + require.NoError(t, err) + assert.Equal(t, peer.AccountID, existingAccountID) + assert.Equal(t, peer.UserID, existingUserID) - key, err := wgtypes.GeneratePrivateKey() - require.NoError(t, err) + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.Contains(t, account.Peers, addedPeer.ID) + assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID) - expectedPeerKey := key.PublicKey().String() - peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ - Key: expectedPeerKey, - LoginExpirationEnabled: true, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - require.NoError(t, err) + assert.Equal(t, uint64(1), account.Network.Serial) - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldReceiveUpdate") - } - }) + lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin) +} - //Deleting peer with a used group in active dns, route or acl should update account peers and send peer update - t.Run("deleting peer with used group", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() +func Test_RegisterPeerBySetupKey(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } - err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) - require.NoError(t, err) + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldReceiveUpdate") - } - }) + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "existingSetupKey", + UserID: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + + addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer) + + require.NoError(t, err) + + peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + require.NoError(t, err) + assert.Equal(t, peer.AccountID, existingAccountID) + assert.Equal(t, peer.SetupKey, existingSetupKeyID) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.Contains(t, account.Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID) + + assert.Equal(t, uint64(1), account.Network.Serial) + + lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed) + assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) } + +func Test_RegisterPeerRollbackOnFailure(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "existingSetupKey", + UserID: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + + _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + require.Error(t, err) + + _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + require.Error(t, err) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.NotContains(t, account.Peers, newPeer.ID) + assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) + assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID) + + assert.Equal(t, uint64(0), account.Network.Serial) + + lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) +} diff --git a/management/server/policy.go b/management/server/policy.go index e3a6df3f2..75647de44 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,6 +3,7 @@ package server import ( "context" _ "embed" + "slices" "strconv" "strings" @@ -75,6 +76,12 @@ type PolicyUpdateOperation struct { Values []string } +// RulePortRange represents a range of ports for a firewall rule. +type RulePortRange struct { + Start uint16 + End uint16 +} + // PolicyRule is the metadata of the policy type PolicyRule struct { // ID of the policy rule @@ -109,6 +116,9 @@ type PolicyRule struct { // Ports or it ranges list Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. + PortRanges []RulePortRange `gorm:"serializer:json"` } // Copy returns a copy of a policy rule @@ -124,10 +134,12 @@ func (pm *PolicyRule) Copy() *PolicyRule { Bidirectional: pm.Bidirectional, Protocol: pm.Protocol, Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), } copy(rule.Destinations, pm.Destinations) copy(rule.Sources, pm.Sources) copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) return rule } @@ -191,18 +203,6 @@ func (p *Policy) UpgradeAndFix() { } } -// ruleGroups returns a list of all groups referenced in the policy's rules, -// including sources and destinations. -func (p *Policy) ruleGroups() []string { - groups := make([]string, 0) - for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) - groups = append(groups, rule.Destinations...) - } - - return groups -} - // FirewallRule is a rule of the firewall. type FirewallRule struct { // PeerIP of the peer @@ -326,34 +326,20 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - for _, policy := range account.Policies { - if policy.ID == policyID { - return policy, nil - } - } - - return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID) + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -362,7 +348,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - exists, updateAccountPeers := am.savePolicy(account, policy) + if err = am.savePolicy(account, policy, isUpdate); err != nil { + return err + } account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { @@ -370,14 +358,12 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } action := activity.PolicyAdded - if exists { + if isUpdate { action = activity.PolicyUpdated } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - if updateAccountPeers { - am.updateAccountPeers(ctx, account) - } + am.updateAccountPeers(ctx, account) return nil } @@ -404,33 +390,23 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, account) - } + am.updateAccountPeers(ctx, account) return nil } // ListPolicies from the store func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies") - } - - return account.Policies, nil + return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { @@ -450,54 +426,47 @@ func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) return policy, nil } -func (am *DefaultAccountManager) savePolicy(account *Account, policy *Policy) (exists, updateAccountPeers bool) { - for i, p := range account.Policies { - if p.ID == policy.ID { - account.Policies[i] = policy +// savePolicy saves or updates a policy in the given account. +// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. +func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) error { + for index, rule := range policyToSave.Rules { + rule.Sources = filterValidGroupIDs(account, rule.Sources) + rule.Destinations = filterValidGroupIDs(account, rule.Destinations) + policyToSave.Rules[index] = rule + } - exists = true - updateAccountPeers = anyGroupHasPeers(account, p.ruleGroups()) || anyGroupHasPeers(account, policy.ruleGroups()) + if policyToSave.SourcePostureChecks != nil { + policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) + } - break + if isUpdate { + policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) + if policyIdx < 0 { + return status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) } + + // Update the existing policy + account.Policies[policyIdx] = policyToSave + return nil } - if !exists { - account.Policies = append(account.Policies, policy) - updateAccountPeers = anyGroupHasPeers(account, policy.ruleGroups()) - } - return + + // Add the new policy to the account + account.Policies = append(account.Policies, policyToSave) + + return nil } -func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(update)) - for i := range update { - direction := proto.FirewallRule_IN - if update[i].Direction == firewallRuleDirectionOUT { - direction = proto.FirewallRule_OUT - } - action := proto.FirewallRule_ACCEPT - if update[i].Action == string(PolicyTrafficActionDrop) { - action = proto.FirewallRule_DROP - } - - protocol := proto.FirewallRule_UNKNOWN - switch PolicyRuleProtocolType(update[i].Protocol) { - case PolicyRuleProtocolALL: - protocol = proto.FirewallRule_ALL - case PolicyRuleProtocolTCP: - protocol = proto.FirewallRule_TCP - case PolicyRuleProtocolUDP: - protocol = proto.FirewallRule_UDP - case PolicyRuleProtocolICMP: - protocol = proto.FirewallRule_ICMP - } +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] result[i] = &proto.FirewallRule{ - PeerIP: update[i].PeerIP, - Direction: direction, - Action: action, - Protocol: protocol, - Port: update[i].Port, + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result @@ -580,3 +549,29 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } return nil } + +// filterValidPostureChecks filters and returns the posture check IDs from the given list +// that are valid within the provided account. +func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { + result := make([]string, 0, len(postureChecksIds)) + for _, id := range postureChecksIds { + for _, postureCheck := range account.PostureChecks { + if id == postureCheck.ID { + result = append(result, id) + continue + } + } + } + return result +} + +// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. +func filterValidGroupIDs(account *Account, groupIDs []string) []string { + result := make([]string, 0, len(groupIDs)) + for _, groupID := range groupIDs { + if _, exists := account.Groups[groupID]; exists { + result = append(result, groupID) + } + } + return result +} diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index c0a9c6a65..ca4946703 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -15,30 +15,16 @@ const ( ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - for _, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks, nil - } - } - - return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID) + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) } func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { @@ -123,24 +109,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun } func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - return account.PostureChecks, nil + return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { diff --git a/management/server/route.go b/management/server/route.go index 3e3842ab1..39ee6170c 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,9 +4,15 @@ import ( "context" "fmt" "net/netip" + "slices" + "strconv" + "strings" "unicode/utf8" "github.com/rs/xid" + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -15,31 +21,42 @@ import ( "github.com/netbirdio/netbird/route" ) +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. + SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} + // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - wantedRoute, found := account.Routes[routeID] - if found { - return wantedRoute, nil - } - - return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID) + return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. @@ -125,7 +142,7 @@ func getRouteDescriptor(prefix netip.Prefix, domains domain.List) string { } // CreateRoute creates and saves a new route -func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { +func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -134,6 +151,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } + // Do not allow non-Linux peers + if peer := account.GetPeer(peerID); peer != nil { + if peer.Meta.GoOS != "linux" { + return nil, status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(domains) > 0 && prefix.IsValid() { return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -163,6 +187,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } } + if len(accessControlGroupIDs) > 0 { + err = validateGroups(accessControlGroupIDs, account.Groups) + if err != nil { + return nil, err + } + } + err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) if err != nil { return nil, err @@ -193,6 +224,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.Enabled = enabled newRoute.Groups = groups newRoute.KeepRoute = keepRoute + newRoute.AccessControlGroups = accessControlGroupIDs if account.Routes == nil { account.Routes = make(map[route.ID]*route.Route) @@ -205,9 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, account) - } + am.updateAccountPeers(ctx, account) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -236,6 +266,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } + // Do not allow non-Linux peers + if peer := account.GetPeer(routeToSave.Peer); peer != nil { + if peer.Meta.GoOS != "linux" { + return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes") + } + } + if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") } @@ -259,6 +296,13 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } } + if len(routeToSave.AccessControlGroups) > 0 { + err = validateGroups(routeToSave.AccessControlGroups, account.Groups) + if err != nil { + return err + } + } + err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err != nil { return err @@ -269,7 +313,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - oldRoute := account.Routes[routeToSave.ID] account.Routes[routeToSave.ID] = routeToSave account.Network.IncSerial() @@ -277,9 +320,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, account) - } + am.updateAccountPeers(ctx, account) am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -296,8 +337,8 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } - route := account.Routes[routeID] - if route == nil { + routy := account.Routes[routeID] + if routy == nil { return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID) } delete(account.Routes, routeID) @@ -307,40 +348,25 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } - if isRouteChangeAffectPeers(account, route) { - am.updateAccountPeers(ctx, account) - } + am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) + am.updateAccountPeers(ctx, account) return nil } // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - routes := make([]*route.Route, 0, len(account.Routes)) - for _, item := range account.Routes { - routes = append(routes, item) - } - - return routes, nil + return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { @@ -371,8 +397,247 @@ func getPlaceholderIP() netip.Prefix { return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -// isRouteChangeAffectPeers checks if a given route affects peers by determining -// if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. +func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) + + enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) + for _, route := range enabledRoutes { + // If no access control groups are specified, accept all traffic. + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + } + + return routesFirewallRules +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. +func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + rulesExists := make(map[string]struct{}) + rules := make([]*RouteFirewallRule, 0) + + sourceRanges := make([]string, 0, len(groupPeers)) + for _, peer := range groupPeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) + } + + baseRule := RouteFirewallRule{ + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + IsDynamic: route.IsDynamic(), + } + + // generate rule for port range + if len(rule.Ports) == 0 { + rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) + } else { + rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) + + } + + // TODO: generate IPv6 rules for dynamic routes + + return rules +} + +// generateRuleIDBase generates the base rule ID for checking duplicates. +func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} + +// generateRulesForPeer generates rules for a given peer based on ports and port ranges. +func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + + ruleIDBase := generateRuleIDBase(rule, baseRule) + if len(rule.Ports) == 0 { + if len(rule.PortRanges) == 0 { + if _, ok := rulesExists[ruleIDBase]; !ok { + rulesExists[ruleIDBase] = struct{}{} + rules = append(rules, &baseRule) + } + } else { + for _, portRange := range rule.PortRanges { + ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) + if _, ok := rulesExists[ruleID]; !ok { + rulesExists[ruleID] = struct{}{} + pr := baseRule + pr.PortRange = portRange + rules = append(rules, &pr) + } + } + } + return rules + } + + return rules +} + +// generateRulesWithPorts generates rules when specific ports are provided. +func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { + rules := make([]*RouteFirewallRule, 0) + ruleIDBase := generateRuleIDBase(rule, baseRule) + + for _, port := range rule.Ports { + ruleID := ruleIDBase + port + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + pr := baseRule + p, err := strconv.ParseUint(port, 10, 16) + if err != nil { + log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) + continue + } + + pr.Port = uint16(p) + rules = append(rules, &pr) + } + + return rules +} + +func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { + result := make([]*proto.RouteFirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + result[i] = &proto.RouteFirewallRule{ + SourceRanges: rule.SourceRanges, + Action: getProtoAction(rule.Action), + Destination: rule.Destination, + Protocol: getProtoProtocol(rule.Protocol), + PortInfo: getProtoPortInfo(rule), + IsDynamic: rule.IsDynamic, + } + } + + return result +} + +// getProtoDirection converts the direction to proto.RuleDirection. +func getProtoDirection(direction int) proto.RuleDirection { + if direction == firewallRuleDirectionOUT { + return proto.RuleDirection_OUT + } + return proto.RuleDirection_IN +} + +// getProtoAction converts the action to proto.RuleAction. +func getProtoAction(action string) proto.RuleAction { + if action == string(PolicyTrafficActionDrop) { + return proto.RuleAction_DROP + } + return proto.RuleAction_ACCEPT +} + +// getProtoProtocol converts the protocol to proto.RuleProtocol. +func getProtoProtocol(protocol string) proto.RuleProtocol { + switch PolicyRuleProtocolType(protocol) { + case PolicyRuleProtocolALL: + return proto.RuleProtocol_ALL + case PolicyRuleProtocolTCP: + return proto.RuleProtocol_TCP + case PolicyRuleProtocolUDP: + return proto.RuleProtocol_UDP + case PolicyRuleProtocolICMP: + return proto.RuleProtocol_ICMP + default: + return proto.RuleProtocol_UNKNOWN + } +} + +// getProtoPortInfo converts the port info to proto.PortInfo. +func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { + var portInfo proto.PortInfo + if rule.Port != 0 { + portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} + } else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 { + portInfo.PortSelection = &proto.PortInfo_Range_{ + Range: &proto.PortInfo_Range{ + Start: uint32(portRange.Start), + End: uint32(portRange.End), + }, + } + } + return &portInfo } diff --git a/management/server/route_test.go b/management/server/route_test.go index 8dcdef884..b556816be 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2,9 +2,10 @@ package server import ( "context" + "fmt" + "net" "net/netip" "testing" - "time" "github.com/rs/xid" "github.com/stretchr/testify/assert" @@ -45,18 +46,19 @@ var existingDomains = domain.List{"example.com"} func TestCreateRoute(t *testing.T) { type input struct { - network netip.Prefix - domains domain.List - keepRoute bool - networkType route.NetworkType - netID route.NetID - peerKey string - peerGroupIDs []string - description string - masquerade bool - metric int - enabled bool - groups []string + network netip.Prefix + domains domain.List + keepRoute bool + networkType route.NetworkType + netID route.NetID + peerKey string + peerGroupIDs []string + description string + masquerade bool + metric int + enabled bool + groups []string + accessControlGroups []string } testCases := []struct { @@ -70,100 +72,107 @@ func TestCreateRoute(t *testing.T) { { name: "Happy Path Network", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Domains", inputArgs: input{ - domains: domain.List{"domain1", "domain2"}, - keepRoute: true, - networkType: route.DomainNetwork, - netID: "happy", - peerKey: peer1ID, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + domains: domain.List{"domain1", "domain2"}, + keepRoute: true, + networkType: route.DomainNetwork, + netID: "happy", + peerKey: peer1ID, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup1}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.0.2.0/32"), - Domains: domain.List{"domain1", "domain2"}, - NetworkType: route.DomainNetwork, - NetID: "happy", - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, - KeepRoute: true, + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"domain1", "domain2"}, + NetworkType: route.DomainNetwork, + NetID: "happy", + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + KeepRoute: true, + AccessControlGroups: []string{routeGroup1}, }, }, { name: "Happy Path Peer Groups", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1, routeGroup2}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerGroupIDs: []string{routeGroupHA1, routeGroupHA2}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1, routeGroup2}, + accessControlGroups: []string{routeGroup1, routeGroup2}, }, errFunc: require.NoError, shouldCreate: true, expectedRoute: &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetworkType: route.IPv4Network, - NetID: "happy", - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetworkType: route.IPv4Network, + NetID: "happy", + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1, routeGroup2}, }, }, { name: "Both network and domains provided should fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - domains: domain.List{"domain1", "domain2"}, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + domains: domain.List{"domain1", "domain2"}, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -171,16 +180,17 @@ func TestCreateRoute(t *testing.T) { { name: "Both peer and peer_groups Provided Should Fail", inputArgs: input{ - network: netip.MustParsePrefix("192.168.0.0/16"), - networkType: route.IPv4Network, - netID: "happy", - peerKey: peer1ID, - peerGroupIDs: []string{routeGroupHA1}, - description: "super", - masquerade: false, - metric: 9999, - enabled: true, - groups: []string{routeGroup1}, + network: netip.MustParsePrefix("192.168.0.0/16"), + networkType: route.IPv4Network, + netID: "happy", + peerKey: peer1ID, + peerGroupIDs: []string{routeGroupHA1}, + description: "super", + masquerade: false, + metric: 9999, + enabled: true, + groups: []string{routeGroup1}, + accessControlGroups: []string{routeGroup2}, }, errFunc: require.Error, shouldCreate: false, @@ -424,13 +434,13 @@ func TestCreateRoute(t *testing.T) { if testCase.createInitRoute { groupAll, errInit := account.GetGroupAll() require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) + outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) @@ -1038,15 +1048,16 @@ func TestDeleteRoute(t *testing.T) { func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { baseRoute := &route.Route{ - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - PeerGroups: []string{routeGroupHA1, routeGroupHA2}, - Description: "ha route", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1, routeGroup2}, + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + PeerGroups: []string{routeGroupHA1, routeGroupHA2}, + Description: "ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1, routeGroup2}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1063,7 +1074,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1128,16 +1139,17 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { // no routes for peer in different groups // no routes when route is deleted baseRoute := &route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + ID: "testingRoute", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "superNet", + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{routeGroup1}, + AccessControlGroups: []string{routeGroup1}, } am, err := createRouterManager(t) @@ -1154,7 +1166,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1206,7 +1218,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) @@ -1273,11 +1285,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } peer1 := &nbpeer.Peer{ - IP: peer1IP, - ID: peer1ID, - Key: peer1Key, - Name: "test-host1@netbird.io", - UserID: userID, + IP: peer1IP, + ID: peer1ID, + Key: peer1Key, + Name: "test-host1@netbird.io", + DNSLabel: "test-host1", + UserID: userID, Meta: nbpeer.PeerSystemMeta{ Hostname: "test-host1@netbird.io", GoOS: "linux", @@ -1299,11 +1312,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } peer2 := &nbpeer.Peer{ - IP: peer2IP, - ID: peer2ID, - Key: peer2Key, - Name: "test-host2@netbird.io", - UserID: userID, + IP: peer2IP, + ID: peer2ID, + Key: peer2Key, + Name: "test-host2@netbird.io", + DNSLabel: "test-host2", + UserID: userID, Meta: nbpeer.PeerSystemMeta{ Hostname: "test-host2@netbird.io", GoOS: "linux", @@ -1325,11 +1339,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } peer3 := &nbpeer.Peer{ - IP: peer3IP, - ID: peer3ID, - Key: peer3Key, - Name: "test-host3@netbird.io", - UserID: userID, + IP: peer3IP, + ID: peer3ID, + Key: peer3Key, + Name: "test-host3@netbird.io", + DNSLabel: "test-host3", + UserID: userID, Meta: nbpeer.PeerSystemMeta{ Hostname: "test-host3@netbird.io", GoOS: "darwin", @@ -1351,11 +1366,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } peer4 := &nbpeer.Peer{ - IP: peer4IP, - ID: peer4ID, - Key: peer4Key, - Name: "test-host4@netbird.io", - UserID: userID, + IP: peer4IP, + ID: peer4ID, + Key: peer4Key, + Name: "test-host4@netbird.io", + DNSLabel: "test-host4", + UserID: userID, Meta: nbpeer.PeerSystemMeta{ Hostname: "test-host4@netbird.io", GoOS: "linux", @@ -1377,13 +1393,14 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } peer5 := &nbpeer.Peer{ - IP: peer5IP, - ID: peer5ID, - Key: peer5Key, - Name: "test-host4@netbird.io", - UserID: userID, + IP: peer5IP, + ID: peer5ID, + Key: peer5Key, + Name: "test-host5@netbird.io", + DNSLabel: "test-host5", + UserID: userID, Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host4@netbird.io", + Hostname: "test-host5@netbird.io", GoOS: "linux", Kernel: "Linux", Core: "21.04", @@ -1464,90 +1481,299 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return am.Store.GetAccount(context.Background(), account.Id) } -func TestRouteAccountPeerUpdate(t *testing.T) { - manager, err := createRouterManager(t) - require.NoError(t, err, "failed to create account manager") +func TestAccount_getPeersRoutesFirewall(t *testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + ) - account, err := initTestRouteAccount(t, manager) - require.NoError(t, err, "failed to init testing account") - - updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) - t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) - }) - - baseRoute := route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - NetID: "superNet", - NetworkType: route.IPv4Network, - Peer: peer1ID, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - Groups: []string{routeGroup1}, + account := &Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: peer1Key, + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*nbgroup.Group{ + "routingPeer1": { + ID: "routingPeer1", + Name: "RoutingPeer1", + Peers: []string{ + "peerA", + }, + }, + "routingPeer2": { + ID: "routingPeer2", + Name: "RoutingPeer2", + Peers: []string{ + "peerD", + }, + }, + "route1": { + ID: "route1", + Name: "Route1", + Peers: []string{}, + }, + "route2": { + ID: "route2", + Name: "Route2", + Peers: []string{}, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, + }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + }, + Routes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: "route1", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1", "routingPeer2"}, + Description: "Route1 ha route", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"dev"}, + AccessControlGroups: []string{"route1"}, + }, + "route2": { + ID: "route2", + Network: existingNetwork, + NetID: "route2", + NetworkType: route.IPv4Network, + Peer: "peerE", + Description: "Allow", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"finance"}, + AccessControlGroups: []string{"route2"}, + }, + "route3": { + ID: "route3", + Network: netip.MustParsePrefix("192.0.2.0/32"), + Domains: domain.List{"example.com"}, + NetID: "route3", + NetworkType: route.DomainNetwork, + Peer: "peerE", + Description: "Allow all traffic to routed DNS network", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"contractors"}, + AccessControlGroups: []string{}, + }, + }, + Policies: []*Policy{ + { + ID: "RuleRoute1", + Name: "Route1", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute1", + Name: "ruleRoute1", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{ + "dev", + }, + Destinations: []string{ + "route1", + }, + }, + }, + }, + { + ID: "RuleRoute2", + Name: "Route2", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute2", + Name: "ruleRoute2", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + PortRanges: []RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + Destinations: []string{ + "route2", + }, + }, + }, + }, + }, } - // Creating route should not update account peers and send peer update - t.Run("creating route", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } - newRoute, err := manager.CreateRoute( - context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, - baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, - baseRoute.Groups, true, userID, baseRoute.KeepRoute, - ) - require.NoError(t, err) - baseRoute = *newRoute + t.Run("check applied policies for the route", func(t *testing.T) { + route1 := account.Routes["route1"] + policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + assert.Len(t, policies, 1) - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldReceiveUpdate") - } + route2 := account.Routes["route2"] + policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + assert.Len(t, policies, 1) + + route3 := account.Routes["route3"] + policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + assert.Len(t, policies, 0) }) - // Updating the route should update account peers and send peer update - t.Run("updating route", func(t *testing.T) { - baseRoute.Groups = []string{routeGroup1, routeGroup2} + t.Run("check peer routes firewall rules", func(t *testing.T) { + routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + assert.Len(t, routesFirewallRules, 2) - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) - require.NoError(t, err) - - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldReceiveUpdate") + expectedRoutesFirewallRules := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerCIp), + fmt.Sprintf(AllowedIPsFormat, peerHIp), + fmt.Sprintf(AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, } - }) + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) - // Deleting the route should update account peers and send peer update - t.Run("deleting route", func(t *testing.T) { - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() + //peerD is also the routing peer for route1, should contain same routes firewall rules as peerA + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + assert.Len(t, routesFirewallRules, 2) + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) - err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) - require.NoError(t, err) + // peerE is a single routing peer for route 2 and route 3 + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + assert.Len(t, routesFirewallRules, 3) - select { - case <-done: - case <-time.After(time.Second): - t.Error("timeout waiting for peerShouldReceiveUpdate") + expectedRoutesFirewallRules = []*RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: existingNetwork.String(), + Protocol: "tcp", + PortRange: RulePortRange{Start: 80, End: 350}, + }, + { + SourceRanges: []string{"0.0.0.0/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, + { + SourceRanges: []string{"::/0"}, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + IsDynamic: true, + }, } + assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + + // peerC is part of route1 distribution groups but should not receive the routes firewall rules + routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, routesFirewallRules, 0) }) } diff --git a/management/server/scheduler_test.go b/management/server/scheduler_test.go index 7c287a554..fa279d4db 100644 --- a/management/server/scheduler_test.go +++ b/management/server/scheduler_test.go @@ -63,6 +63,7 @@ func TestScheduler_Cancel(t *testing.T) { scheduler.Schedule(context.Background(), scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) { return scheduletime, true }) + defer scheduler.Cancel(context.Background(), []string{jobID2}) time.Sleep(sleepTime) assert.Len(t, scheduler.jobs, 2) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 8ae15726b..e84f8fcd6 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -328,26 +328,24 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + } + + setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") - } - - keys := make([]*SetupKey, 0, len(account.SetupKeys)) - for _, key := range account.SetupKeys { + keys := make([]*SetupKey, 0, len(setupKeys)) + for _, key := range setupKeys { var k *SetupKey - if !(user.HasAdminPower() || user.IsServiceUser) { + if !user.IsAdminOrServiceUser() { k = key.HiddenCopy(999) } else { k = key.Copy() @@ -360,44 +358,30 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { + return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") + } + + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser { - return nil, status.Errorf(status.Unauthorized, "only users with admin power can view policies") - } - - var foundKey *SetupKey - for _, key := range account.SetupKeys { - if key.Id == keyID { - foundKey = key.Copy() - break - } - } - if foundKey == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - // the UpdatedAt field was introduced later, so there might be that some keys have a Zero value (e.g, null in the store file) - if foundKey.UpdatedAt.IsZero() { - foundKey.UpdatedAt = foundKey.CreatedAt + if setupKey.UpdatedAt.IsZero() { + setupKey.UpdatedAt = setupKey.CreatedAt } - if !(user.HasAdminPower() || user.IsServiceUser) { - foundKey = foundKey.HiddenCopy(999) + if !user.IsAdminOrServiceUser() { + setupKey = setupKey.HiddenCopy(999) } - return foundKey, nil + return setupKey, nil } func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index c44ab7f09..85c68ef44 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "path/filepath" "runtime" @@ -33,7 +34,9 @@ import ( const ( storeSqliteFileName = "store.db" idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" + accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" ) @@ -134,6 +137,12 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("SaveAccount for account %s exceeded 1s, took: %v", account.Id, elapsed) + } + }() // todo: remove this check after the issue is resolved s.checkAccountDomainBeforeSave(ctx, account.Id, account.Domain) @@ -391,31 +400,40 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { } func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { - var account Account - - result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") - } - log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting account from store") + accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + if err != nil { + return nil, err } // TODO: rework to not call GetAccount - return s.GetAccount(ctx, account.Id) + return s.GetAccount(ctx, accountID) +} + +func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { + var accountID string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", + strings.ToLower(domain), true, PrivateCategory, + ).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + } + log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error) + return "", status.Errorf(status.Internal, "issue getting account from store") + } + + return accountID, nil } func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey - result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) + result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting setup key from store") + return nil, status.NewSetupKeyNotFoundError() } if key.AccountID == "" { @@ -468,6 +486,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { + var user User + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).First(&user, idQueryCondition, userID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewUserNotFoundError(userID) + } + return nil, status.NewGetUserFromStoreError() + } + + return &user, nil +} + +func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + var groups []*nbgroup.Group + result := s.db.Find(&groups, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting groups from store") + } + + return groups, nil +} + func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { var accounts []Account result := s.db.Find(&accounts) @@ -485,6 +531,13 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { } func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() var account Account result := s.db.Model(&account). @@ -494,7 +547,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewAccountNotFoundError(accountID) } return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -554,7 +607,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { var user User - result := s.db.Select("account_id").First(&user, idQueryCondition, userID) + result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -571,12 +624,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) + result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -590,12 +642,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) + result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -609,12 +660,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string - result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) + result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } @@ -622,9 +672,8 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) } func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { - var user User var accountID string - result := s.db.Model(&user).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -636,61 +685,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { } func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { - var key SetupKey var accountID string - result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) + result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting setup key from store") + return "", status.NewSetupKeyNotFoundError() + } + + if accountID == "" { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } return accountID, nil } -func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { +func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + var ipJSONStrings []string + + // Fetch the IP addresses as JSON strings + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). + Pluck("ip", &ipJSONStrings) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + return nil, status.Errorf(status.Internal, "issue getting IPs from store") + } + + // Convert the JSON strings to net.IP objects + ips := make([]net.IP, len(ipJSONStrings)) + for i, ipJSON := range ipJSONStrings { + var ip net.IP + if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { + return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") + } + ips[i] = ip + } + + return ips, nil +} + +func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + var labels []string + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). + Pluck("dns_label", &labels) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + } + + return labels, nil +} + +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { + var accountNetwork AccountNetwork + + if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.Errorf(status.Internal, "issue getting network from store") + } + return accountNetwork.Network, nil +} + +func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { var peer nbpeer.Peer - result := s.db.First(&peer, "key = ?", peerKey) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting peer from store") } return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { var accountSettings AccountSettings - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting settings from store") } return accountSettings.Settings, nil } // SaveUserLastLogin stores the last login time for a user in DB. -func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { +func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User - result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "user %s not found", userID) + return status.NewUserNotFoundError(userID) } - return status.Errorf(status.Internal, "issue getting user from store") + return status.NewGetUserFromStoreError() } - user.LastLogin = lastLogin - return s.db.Save(user).Error + return s.db.Save(&user).Error } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -809,3 +914,276 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, return store, nil } + +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + var setupKey SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, keyQueryCondition, strings.ToUpper(key)) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + return nil, status.NewSetupKeyNotFoundError() + } + return &setupKey, nil +} + +func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + result := s.db.WithContext(ctx).Model(&SetupKey{}). + Where(idQueryCondition, setupKeyID). + Updates(map[string]interface{}{ + "used_times": gorm.Expr("used_times + 1"), + "last_used": time.Now(), + }) + + if result.Error != nil { + return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "setup key not found") + } + + return nil +} + +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "group 'All' not found for account") + } + return status.Errorf(status.Internal, "issue finding group 'All'") + } + + for _, existingPeerID := range group.Peers { + if existingPeerID == peerID { + return nil + } + } + + group.Peers = append(group.Peers, peerID) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group 'All'") + } + + return nil +} + +func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "group not found for account") + } + return status.Errorf(status.Internal, "issue finding group") + } + + for _, existingPeerID := range group.Peers { + if existingPeerID == peerId { + return nil + } + } + + group.Peers = append(group.Peers, peerId) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group") + } + + return nil +} + +func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { + return status.Errorf(status.Internal, "issue adding peer to account") + } + + return nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + if result.Error != nil { + return status.Errorf(status.Internal, "issue incrementing network serial count") + } + return nil +} + +func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + tx := s.db.WithContext(ctx).Begin() + if tx.Error != nil { + return tx.Error + } + repo := s.withTx(tx) + err := operation(repo) + if err != nil { + tx.Rollback() + return err + } + return tx.Commit().Error +} + +func (s *SqlStore) withTx(tx *gorm.DB) Store { + return &SqlStore{ + db: tx, + } +} + +func (s *SqlStore) GetDB() *gorm.DB { + return s.db +} + +func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) { + var accountDNSSettings AccountDNSSettings + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + First(&accountDNSSettings, idQueryCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "dns settings not found") + } + return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + } + return &accountDNSSettings.DNSSettings, nil +} + +// AccountExists checks whether an account exists by the given ID. +func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { + var accountID string + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Select("id").First(&accountID, idQueryCondition, id) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return false, nil + } + return false, result.Error + } + + return accountID != "", nil +} + +// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. +func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { + var account Account + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + Where(idQueryCondition, accountID).First(&account) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", "", status.Errorf(status.NotFound, "account not found") + } + return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error) + } + + return account.Domain, account.DomainCategory, nil +} + +// GetGroupByID retrieves a group by ID and account ID. +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { + return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +} + +// GetGroupByName retrieves a group by name and account ID. +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + } + return &group, nil +} + +// GetAccountPolicies retrieves policies for an account. +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { + return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) +} + +// GetPolicyByID retrieves a policy by its ID and account ID. +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { + return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) +} + +// GetAccountPostureChecks retrieves posture checks for an account. +func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { + return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetPostureChecksByID retrieves posture checks by their ID and account ID. +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { + return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) +} + +// GetAccountRoutes retrieves network routes for an account. +func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { + return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetRouteByID retrieves a route by its ID and account ID. +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { + return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) +} + +// GetAccountSetupKeys retrieves setup keys for an account. +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { + return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetSetupKeyByID retrieves a setup key by its ID and account ID. +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { + return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +} + +// GetAccountNameServerGroups retrieves name server groups for an account. +func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { + return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) +} + +// GetNameServerGroupByID retrieves a name server group by its ID and account ID. +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { + return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) +} + +// getRecords retrieves records from the database based on the account ID. +func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { + var record []T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) + } + + return record, nil +} + +// getRecordByID retrieves a record by its ID and account ID from the database. +func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) { + var record T + + result := db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&record, accountAndIDQueryCondition, accountID, recordID) + if err := result.Error; err != nil { + parts := strings.Split(fmt.Sprintf("%T", record), ".") + recordType := parts[len(parts)-1] + + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "%s not found", recordType) + } + return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err) + } + return &record, nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index ce4ee531a..64ef36831 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) } + +func TestSqlite_GetTakenIPs(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []net.IP{}, takenIPs) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = store.AddPeerToAccount(context.Background(), peer1) + require.NoError(t, err) + + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip1 := net.IP{1, 1, 1, 1}.To16() + assert.Equal(t, []net.IP{ip1}, takenIPs) + + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: existingAccountID, + IP: net.IP{2, 2, 2, 2}, + } + err = store.AddPeerToAccount(context.Background(), peer2) + require.NoError(t, err) + + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip2 := net.IP{2, 2, 2, 2}.To16() + assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) + +} + +func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{}, labels) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), peer1) + require.NoError(t, err) + + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{"peer1.domain.test"}, labels) + + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: existingAccountID, + DNSLabel: "peer2.domain.test", + } + err = store.AddPeerToAccount(context.Background(), peer2) + require.NoError(t, err) + + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) +} + +func TestSqlite_GetAccountNetwork(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip := net.IP{100, 64, 0, 0}.To16() + assert.Equal(t, ip, network.Net.IP) + assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask) + assert.Equal(t, "", network.Dns) + assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier) + assert.Equal(t, uint64(0), network.Serial) +} + +func TestSqlite_GetSetupKeyBySecret(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) + assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) + assert.Equal(t, "Default key", setupKey.Name) +} + +func TestSqlite_incrementSetupKeyUsage(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 0, setupKey.UsedTimes) + + err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) + require.NoError(t, err) + + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 1, setupKey.UsedTimes) + + err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) + require.NoError(t, err) + + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 2, setupKey.UsedTimes) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 58b9a84a0..d7fde35b9 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error { func NewPeerLoginExpiredError() error { return Errorf(PermissionDenied, "peer login has expired, please log in once more") } + +// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key +func NewSetupKeyNotFoundError() error { + return Errorf(NotFound, "setup key not found") +} + +// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store +func NewGetUserFromStoreError() error { + return Errorf(Internal, "issue getting user from store") +} diff --git a/management/server/store.go b/management/server/store.go index 864871c8e..f34a73c2d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/netbirdio/netbird/dns" log "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -27,45 +28,94 @@ import ( "github.com/netbirdio/netbird/route" ) +type LockingStrength string + +const ( + LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes. + LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions. + LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows. + LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates. +) + type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) - DeleteAccount(ctx context.Context, account *Account) error + AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) + GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(peerKey string) (string, error) + GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) - GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) - GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) - GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) SaveAccount(ctx context.Context, account *Account) error + DeleteAccount(ctx context.Context, account *Account) error + + GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) SaveUsers(accountID string, users map[string]*User) error - SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error + GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + + GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + + GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(accountID string, peer *nbpeer.Peer) error + + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + + GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + + GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) + GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + IncrementNetworkSerial(ctx context.Context, accountId string) error + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error + // AcquireWriteLockByUID should attempt to acquire a lock for write purposes and return a function that releases the lock AcquireWriteLockByUID(ctx context.Context, uniqueID string) func() // AcquireReadLockByUID should attempt to acquire lock for read purposes and return a function that releases the lock AcquireReadLockByUID(ctx context.Context, uniqueID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock(ctx context.Context) func() - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(accountID string, peer *nbpeer.Peer) error - SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error + // Close should close the store persisting all unsaved data. Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) + ExecuteInTransaction(ctx context.Context, f func(store Store) error) error } type StoreEngine string diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index a80453dca..357f019c7 100644 --- a/management/server/telemetry/http_api_metrics.go +++ b/management/server/telemetry/http_api_metrics.go @@ -8,6 +8,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -54,112 +55,89 @@ func (rw *WrappedResponseWriter) WriteHeader(code int) { // HTTPMiddleware handler used to collect metrics of every request/response coming to the API. // Also adds request tracing (logging). type HTTPMiddleware struct { - meter metric.Meter - ctx context.Context + ctx context.Context // all HTTP requests by endpoint & method - httpRequestCounters map[string]metric.Int64Counter + httpRequestCounter metric.Int64Counter // all HTTP responses by endpoint & method & status code - httpResponseCounters map[string]metric.Int64Counter + httpResponseCounter metric.Int64Counter // all HTTP requests totalHTTPRequestsCounter metric.Int64Counter // all HTTP responses totalHTTPResponseCounter metric.Int64Counter // all HTTP responses by status code - totalHTTPResponseCodeCounters map[int]metric.Int64Counter + totalHTTPResponseCodeCounter metric.Int64Counter // all HTTP requests durations by endpoint and method - httpRequestDurations map[string]metric.Int64Histogram + httpRequestDuration metric.Int64Histogram // all HTTP requests durations totalHTTPRequestDuration metric.Int64Histogram } // NewMetricsMiddleware creates a new HTTPMiddleware func NewMetricsMiddleware(ctx context.Context, meter metric.Meter) (*HTTPMiddleware, error) { - totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpRequestCounterPrefix), metric.WithUnit("1")) + httpRequestCounter, err := meter.Int64Counter(httpRequestCounterPrefix, metric.WithUnit("1")) if err != nil { return nil, err } - totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s_total", httpResponseCounterPrefix), metric.WithUnit("1")) + httpResponseCounter, err := meter.Int64Counter(httpResponseCounterPrefix, metric.WithUnit("1")) if err != nil { return nil, err } - totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s_total", httpRequestDurationPrefix), metric.WithUnit("milliseconds")) + totalHTTPRequestsCounter, err := meter.Int64Counter(fmt.Sprintf("%s.total", httpRequestCounterPrefix), metric.WithUnit("1")) + if err != nil { + return nil, err + } + + totalHTTPResponseCounter, err := meter.Int64Counter(fmt.Sprintf("%s.total", httpResponseCounterPrefix), metric.WithUnit("1")) + if err != nil { + return nil, err + } + + totalHTTPResponseCodeCounter, err := meter.Int64Counter(fmt.Sprintf("%s.code.total", httpResponseCounterPrefix), metric.WithUnit("1")) + if err != nil { + return nil, err + } + + httpRequestDuration, err := meter.Int64Histogram(httpRequestDurationPrefix, metric.WithUnit("milliseconds")) + if err != nil { + return nil, err + } + + totalHTTPRequestDuration, err := meter.Int64Histogram(fmt.Sprintf("%s.total", httpRequestDurationPrefix), metric.WithUnit("milliseconds")) if err != nil { return nil, err } return &HTTPMiddleware{ - ctx: ctx, - httpRequestCounters: map[string]metric.Int64Counter{}, - httpResponseCounters: map[string]metric.Int64Counter{}, - httpRequestDurations: map[string]metric.Int64Histogram{}, - totalHTTPResponseCodeCounters: map[int]metric.Int64Counter{}, - meter: meter, - totalHTTPRequestsCounter: totalHTTPRequestsCounter, - totalHTTPResponseCounter: totalHTTPResponseCounter, - totalHTTPRequestDuration: totalHTTPRequestDuration, + ctx: ctx, + httpRequestCounter: httpRequestCounter, + httpResponseCounter: httpResponseCounter, + httpRequestDuration: httpRequestDuration, + totalHTTPResponseCodeCounter: totalHTTPResponseCodeCounter, + totalHTTPRequestsCounter: totalHTTPRequestsCounter, + totalHTTPResponseCounter: totalHTTPResponseCounter, + totalHTTPRequestDuration: totalHTTPRequestDuration, }, nil } -// AddHTTPRequestResponseCounter adds a new meter for an HTTP defaultEndpoint and Method (GET, POST, etc) -// Creates one request counter and multiple response counters (one per http response status code). -func (m *HTTPMiddleware) AddHTTPRequestResponseCounter(endpoint string, method string) error { - meterKey := getRequestCounterKey(endpoint, method) - httpReqCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1")) - if err != nil { - return err - } - m.httpRequestCounters[meterKey] = httpReqCounter - - durationKey := getRequestDurationKey(endpoint, method) - requestDuration, err := m.meter.Int64Histogram(durationKey, metric.WithUnit("milliseconds")) - if err != nil { - return err - } - m.httpRequestDurations[durationKey] = requestDuration - - respCodes := []int{200, 204, 400, 401, 403, 404, 500, 502, 503} - for _, code := range respCodes { - meterKey = getResponseCounterKey(endpoint, method, code) - httpRespCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1")) - if err != nil { - return err - } - m.httpResponseCounters[meterKey] = httpRespCounter - - meterKey = fmt.Sprintf("%s_%d_total", httpResponseCounterPrefix, code) - totalHTTPResponseCodeCounter, err := m.meter.Int64Counter(meterKey, metric.WithUnit("1")) - if err != nil { - return err - } - m.totalHTTPResponseCodeCounters[code] = totalHTTPResponseCodeCounter - } - - return nil -} - func replaceEndpointChars(endpoint string) string { - endpoint = strings.ReplaceAll(endpoint, "/", "_") endpoint = strings.ReplaceAll(endpoint, "{", "") endpoint = strings.ReplaceAll(endpoint, "}", "") return endpoint } -func getRequestCounterKey(endpoint, method string) string { - endpoint = replaceEndpointChars(endpoint) - return fmt.Sprintf("%s%s_%s", httpRequestCounterPrefix, endpoint, method) -} - -func getRequestDurationKey(endpoint, method string) string { - endpoint = replaceEndpointChars(endpoint) - return fmt.Sprintf("%s%s_%s", httpRequestDurationPrefix, endpoint, method) -} - -func getResponseCounterKey(endpoint, method string, status int) string { - endpoint = replaceEndpointChars(endpoint) - return fmt.Sprintf("%s%s_%s_%d", httpResponseCounterPrefix, endpoint, method, status) +func getEndpointMetricAttr(r *http.Request) string { + var endpoint string + route := mux.CurrentRoute(r) + if route != nil { + pathTmpl, err := route.GetPathTemplate() + if err == nil { + endpoint = replaceEndpointChars(pathTmpl) + } + } + return endpoint } // Handler logs every request and response and adds the, to metrics. @@ -176,11 +154,10 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { log.WithContext(ctx).Tracef("HTTP request %v: %v %v", reqID, r.Method, r.URL) - metricKey := getRequestCounterKey(r.URL.Path, r.Method) + endpointAttr := attribute.String("endpoint", getEndpointMetricAttr(r)) + methodAttr := attribute.String("method", r.Method) - if c, ok := m.httpRequestCounters[metricKey]; ok { - c.Add(m.ctx, 1) - } + m.httpRequestCounter.Add(m.ctx, 1, metric.WithAttributes(endpointAttr, methodAttr)) m.totalHTTPRequestsCounter.Add(m.ctx, 1) w := WrapResponseWriter(rw) @@ -193,21 +170,14 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { log.WithContext(ctx).Tracef("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) } - metricKey = getResponseCounterKey(r.URL.Path, r.Method, w.Status()) - if c, ok := m.httpResponseCounters[metricKey]; ok { - c.Add(m.ctx, 1) - } + statusCodeAttr := attribute.Int("code", w.Status()) + m.httpResponseCounter.Add(m.ctx, 1, metric.WithAttributes(endpointAttr, methodAttr, statusCodeAttr)) m.totalHTTPResponseCounter.Add(m.ctx, 1) - if c, ok := m.totalHTTPResponseCodeCounters[w.Status()]; ok { - c.Add(m.ctx, 1) - } + m.totalHTTPResponseCodeCounter.Add(m.ctx, 1, metric.WithAttributes(statusCodeAttr)) - durationKey := getRequestDurationKey(r.URL.Path, r.Method) reqTook := time.Since(reqStart) - if c, ok := m.httpRequestDurations[durationKey]; ok { - c.Record(m.ctx, reqTook.Milliseconds()) - } + m.httpRequestDuration.Record(m.ctx, reqTook.Milliseconds(), metric.WithAttributes(endpointAttr, methodAttr)) log.WithContext(ctx).Debugf("request %s %s took %d ms and finished with status %d", r.Method, r.URL.Path, reqTook.Milliseconds(), w.Status()) if w.Status() == 200 && (r.Method == http.MethodPut || r.Method == http.MethodPost || r.Method == http.MethodDelete) { diff --git a/management/server/testdata/GeoLite2-City-Test.mmdb b/management/server/testdata/GeoLite2-City_20240305.mmdb similarity index 100% rename from management/server/testdata/GeoLite2-City-Test.mmdb rename to management/server/testdata/GeoLite2-City_20240305.mmdb diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json new file mode 100644 index 000000000..7f96e57a8 --- /dev/null +++ b/management/server/testdata/extended-store.json @@ -0,0 +1,120 @@ +{ + "Accounts": { + "bf1c8084-ba50-4ce7-9439-34653001fc3b": { + "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", + "CreatedBy": "", + "Domain": "test.com", + "DomainCategory": "private", + "IsDomainPrimaryAccount": true, + "SetupKeys": { + "A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { + "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", + "AccountID": "", + "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", + "Name": "Default key", + "Type": "reusable", + "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", + "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", + "Revoked": false, + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": ["cfefqs706sqkneg59g2g"], + "UsageLimit": 0, + "Ephemeral": false + }, + "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { + "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", + "AccountID": "", + "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", + "Name": "Faulty key with non existing group", + "Type": "reusable", + "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", + "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", + "Revoked": false, + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": ["abcd"], + "UsageLimit": 0, + "Ephemeral": false + } + }, + "Network": { + "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", + "Net": { + "IP": "100.64.0.0", + "Mask": "//8AAA==" + }, + "Dns": "", + "Serial": 0 + }, + "Peers": {}, + "Users": { + "edafee4e-63fb-11ec-90d6-0242ac120003": { + "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", + "AccountID": "", + "Role": "admin", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": ["cfefqs706sqkneg59g3g"], + "PATs": {}, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" + }, + "f4f6d672-63fb-11ec-90d6-0242ac120003": { + "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", + "AccountID": "", + "Role": "user", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": null, + "PATs": { + "9dj38s35-63fb-11ec-90d6-0242ac120003": { + "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", + "UserID": "", + "Name": "", + "HashedToken": "SoMeHaShEdToKeN", + "ExpirationDate": "2023-02-27T00:00:00Z", + "CreatedBy": "user", + "CreatedAt": "2023-01-01T00:00:00Z", + "LastUsed": "2023-02-01T00:00:00Z" + } + }, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" + } + }, + "Groups": { + "cfefqs706sqkneg59g4g": { + "ID": "cfefqs706sqkneg59g4g", + "Name": "All", + "Peers": [] + }, + "cfefqs706sqkneg59g3g": { + "ID": "cfefqs706sqkneg59g3g", + "Name": "AwesomeGroup1", + "Peers": [] + }, + "cfefqs706sqkneg59g2g": { + "ID": "cfefqs706sqkneg59g2g", + "Name": "AwesomeGroup2", + "Peers": [] + } + }, + "Rules": null, + "Policies": [], + "Routes": null, + "NameServerGroups": null, + "DNSSettings": null, + "Settings": { + "PeerLoginExpirationEnabled": false, + "PeerLoginExpiration": 86400000000000, + "GroupsPropagationEnabled": false, + "JWTGroupsEnabled": false, + "JWTGroupsClaimName": "" + } + } + }, + "InstallationID": "" +} diff --git a/management/server/testdata/geonames-test.db b/management/server/testdata/geonames_20240305.db similarity index 100% rename from management/server/testdata/geonames-test.db rename to management/server/testdata/geonames_20240305.db diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go new file mode 100644 index 000000000..ef8276b59 --- /dev/null +++ b/management/server/token_mgr.go @@ -0,0 +1,232 @@ +package server + +import ( + "context" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "fmt" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/proto" + auth "github.com/netbirdio/netbird/relay/auth/hmac" + authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" +) + +const defaultDuration = 12 * time.Hour + +// SecretsManager used to manage TURN and relay secrets +type SecretsManager interface { + GenerateTurnToken() (*Token, error) + GenerateRelayToken() (*Token, error) + SetupRefresh(ctx context.Context, peerKey string) + CancelRefresh(peerKey string) +} + +// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server +type TimeBasedAuthSecretsManager struct { + mux sync.Mutex + turnCfg *TURNConfig + relayCfg *Relay + turnHmacToken *auth.TimedHMAC + relayHmacToken *authv2.Generator + updateManager *PeersUpdateManager + turnCancelMap map[string]chan struct{} + relayCancelMap map[string]chan struct{} +} + +type Token auth.Token + +func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, turnCfg *TURNConfig, relayCfg *Relay) *TimeBasedAuthSecretsManager { + mgr := &TimeBasedAuthSecretsManager{ + updateManager: updateManager, + turnCfg: turnCfg, + relayCfg: relayCfg, + turnCancelMap: make(map[string]chan struct{}), + relayCancelMap: make(map[string]chan struct{}), + } + + if turnCfg != nil { + duration := turnCfg.CredentialsTTL.Duration + if turnCfg.CredentialsTTL.Duration <= 0 { + log.Warnf("TURN credentials TTL is not set or invalid, using default value %s", defaultDuration) + duration = defaultDuration + } + mgr.turnHmacToken = auth.NewTimedHMAC(turnCfg.Secret, duration) + } + + if relayCfg != nil { + duration := relayCfg.CredentialsTTL.Duration + if relayCfg.CredentialsTTL.Duration <= 0 { + log.Warnf("Relay credentials TTL is not set or invalid, using default value %s", defaultDuration) + duration = defaultDuration + } + + hashedSecret := sha256.Sum256([]byte(relayCfg.Secret)) + var err error + if mgr.relayHmacToken, err = authv2.NewGenerator(authv2.AuthAlgoHMACSHA256, hashedSecret[:], duration); err != nil { + log.Errorf("failed to create relay token generator: %s", err) + } + } + + return mgr +} + +// GenerateTurnToken generates new time-based secret credentials for TURN +func (m *TimeBasedAuthSecretsManager) GenerateTurnToken() (*Token, error) { + if m.turnHmacToken == nil { + return nil, fmt.Errorf("TURN configuration is not set") + } + turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) + if err != nil { + return nil, fmt.Errorf("generate TURN token: %s", err) + } + return (*Token)(turnToken), nil +} + +// GenerateRelayToken generates new time-based secret credentials for relay +func (m *TimeBasedAuthSecretsManager) GenerateRelayToken() (*Token, error) { + if m.relayHmacToken == nil { + return nil, fmt.Errorf("relay configuration is not set") + } + relayToken, err := m.relayHmacToken.GenerateToken() + if err != nil { + return nil, fmt.Errorf("generate relay token: %s", err) + } + + return &Token{ + Payload: string(relayToken.Payload), + Signature: base64.StdEncoding.EncodeToString(relayToken.Signature), + }, nil +} + +func (m *TimeBasedAuthSecretsManager) cancelTURN(peerID string) { + if channel, ok := m.turnCancelMap[peerID]; ok { + close(channel) + delete(m.turnCancelMap, peerID) + } +} + +func (m *TimeBasedAuthSecretsManager) cancelRelay(peerID string) { + if channel, ok := m.relayCancelMap[peerID]; ok { + close(channel) + delete(m.relayCancelMap, peerID) + } +} + +// CancelRefresh cancels scheduled peer credentials refresh +func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) { + m.mux.Lock() + defer m.mux.Unlock() + m.cancelTURN(peerID) + m.cancelRelay(peerID) +} + +// SetupRefresh starts peer credentials refresh +func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) { + m.mux.Lock() + defer m.mux.Unlock() + + m.cancelTURN(peerID) + m.cancelRelay(peerID) + + if m.turnCfg != nil && m.turnCfg.TimeBasedCredentials { + turnCancel := make(chan struct{}, 1) + m.turnCancelMap[peerID] = turnCancel + go m.refreshTURNTokens(ctx, peerID, turnCancel) + log.WithContext(ctx).Debugf("starting TURN refresh for %s", peerID) + } + + if m.relayCfg != nil { + relayCancel := make(chan struct{}, 1) + m.relayCancelMap[peerID] = relayCancel + go m.refreshRelayTokens(ctx, peerID, relayCancel) + log.WithContext(ctx).Debugf("starting relay refresh for %s", peerID) + } +} + +func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, peerID string, cancel chan struct{}) { + ticker := time.NewTicker(m.turnCfg.CredentialsTTL.Duration / 4 * 3) + defer ticker.Stop() + + for { + select { + case <-cancel: + log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) + return + case <-ticker.C: + m.pushNewTURNTokens(ctx, peerID) + } + } +} + +func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, peerID string, cancel chan struct{}) { + ticker := time.NewTicker(m.relayCfg.CredentialsTTL.Duration / 4 * 3) + defer ticker.Stop() + + for { + select { + case <-cancel: + log.WithContext(ctx).Debugf("stopping relay refresh for %s", peerID) + return + case <-ticker.C: + m.pushNewRelayTokens(ctx, peerID) + } + } +} + +func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, peerID string) { + turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) + if err != nil { + log.Errorf("failed to generate token for peer '%s': %s", peerID, err) + return + } + + var turns []*proto.ProtectedHostConfig + for _, host := range m.turnCfg.Turns { + turn := &proto.ProtectedHostConfig{ + HostConfig: &proto.HostConfig{ + Uri: host.URI, + Protocol: ToResponseProto(host.Proto), + }, + User: turnToken.Payload, + Password: turnToken.Signature, + } + turns = append(turns, turn) + } + + update := &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{ + Turns: turns, + // omit Relay to avoid updates there + }, + } + + log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) + m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) +} + +func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, peerID string) { + relayToken, err := m.relayHmacToken.GenerateToken() + if err != nil { + log.Errorf("failed to generate relay token for peer '%s': %s", peerID, err) + return + } + + update := &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{ + Relay: &proto.RelayConfig{ + Urls: m.relayCfg.Addresses, + TokenPayload: string(relayToken.Payload), + TokenSignature: base64.StdEncoding.EncodeToString(relayToken.Signature), + }, + // omit Turns to avoid updates there + }, + } + + log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) + m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) +} diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go new file mode 100644 index 000000000..3e63346c2 --- /dev/null +++ b/management/server/token_mgr_test.go @@ -0,0 +1,219 @@ +package server + +import ( + "context" + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "hash" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/util" +) + +var TurnTestHost = &Host{ + Proto: UDP, + URI: "turn:turn.wiretrustee.com:77777", + Username: "username", + Password: "", +} + +func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { + ttl := util.Duration{Duration: time.Hour} + secret := "some_secret" + peersManager := NewPeersUpdateManager(nil) + + rc := &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: ttl, + Secret: secret, + } + + tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ + CredentialsTTL: ttl, + Secret: secret, + Turns: []*Host{TurnTestHost}, + TimeBasedCredentials: true, + }, rc) + + turnCredentials, err := tested.GenerateTurnToken() + require.NoError(t, err) + + if turnCredentials.Payload == "" { + t.Errorf("expected generated TURN username not to be empty, got empty") + } + if turnCredentials.Signature == "" { + t.Errorf("expected generated TURN password not to be empty, got empty") + } + + validateMAC(t, sha1.New, turnCredentials.Payload, turnCredentials.Signature, []byte(secret)) + + relayCredentials, err := tested.GenerateRelayToken() + require.NoError(t, err) + + if relayCredentials.Payload == "" { + t.Errorf("expected generated relay payload not to be empty, got empty") + } + if relayCredentials.Signature == "" { + t.Errorf("expected generated relay signature not to be empty, got empty") + } + + hashedSecret := sha256.Sum256([]byte(secret)) + validateMAC(t, sha256.New, relayCredentials.Payload, relayCredentials.Signature, hashedSecret[:]) +} + +func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { + ttl := util.Duration{Duration: 2 * time.Second} + secret := "some_secret" + peersManager := NewPeersUpdateManager(nil) + peer := "some_peer" + updateChannel := peersManager.CreateChannel(context.Background(), peer) + + rc := &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: ttl, + Secret: secret, + } + tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ + CredentialsTTL: ttl, + Secret: secret, + Turns: []*Host{TurnTestHost}, + TimeBasedCredentials: true, + }, rc) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tested.SetupRefresh(ctx, peer) + + if _, ok := tested.turnCancelMap[peer]; !ok { + t.Errorf("expecting peer to be present in the turn cancel map, got not present") + } + + if _, ok := tested.relayCancelMap[peer]; !ok { + t.Errorf("expecting peer to be present in the relay cancel map, got not present") + } + + var updates []*UpdateMessage + +loop: + for timeout := time.After(5 * time.Second); ; { + select { + case update := <-updateChannel: + updates = append(updates, update) + case <-timeout: + break loop + } + + if len(updates) >= 2 { + break loop + } + } + + if len(updates) < 2 { + t.Errorf("expecting at least 2 peer credentials updates, got %v", len(updates)) + } + + var turnUpdates, relayUpdates int + var firstTurnUpdate, secondTurnUpdate *proto.ProtectedHostConfig + var firstRelayUpdate, secondRelayUpdate *proto.RelayConfig + + for _, update := range updates { + if turns := update.Update.GetWiretrusteeConfig().GetTurns(); len(turns) > 0 { + turnUpdates++ + if turnUpdates == 1 { + firstTurnUpdate = turns[0] + } else { + secondTurnUpdate = turns[0] + } + } + if relay := update.Update.GetWiretrusteeConfig().GetRelay(); relay != nil { + relayUpdates++ + if relayUpdates == 1 { + firstRelayUpdate = relay + } else { + secondRelayUpdate = relay + } + } + } + + if turnUpdates < 1 { + t.Errorf("expecting at least 1 TURN credential update, got %v", turnUpdates) + } + if relayUpdates < 1 { + t.Errorf("expecting at least 1 relay credential update, got %v", relayUpdates) + } + + if firstTurnUpdate != nil && secondTurnUpdate != nil { + if firstTurnUpdate.Password == secondTurnUpdate.Password { + t.Errorf("expecting first TURN credential update password %v to be different from second, got equal", firstTurnUpdate.Password) + } + } + + if firstRelayUpdate != nil && secondRelayUpdate != nil { + if firstRelayUpdate.TokenSignature == secondRelayUpdate.TokenSignature { + t.Errorf("expecting first relay credential update signature %v to be different from second, got equal", firstRelayUpdate.TokenSignature) + } + } +} + +func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { + ttl := util.Duration{Duration: time.Hour} + secret := "some_secret" + peersManager := NewPeersUpdateManager(nil) + peer := "some_peer" + + rc := &Relay{ + Addresses: []string{"localhost:0"}, + CredentialsTTL: ttl, + Secret: secret, + } + tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ + CredentialsTTL: ttl, + Secret: secret, + Turns: []*Host{TurnTestHost}, + TimeBasedCredentials: true, + }, rc) + + tested.SetupRefresh(context.Background(), peer) + if _, ok := tested.turnCancelMap[peer]; !ok { + t.Errorf("expecting peer to be present in turn cancel map, got not present") + } + if _, ok := tested.relayCancelMap[peer]; !ok { + t.Errorf("expecting peer to be present in relay cancel map, got not present") + } + + tested.CancelRefresh(peer) + if _, ok := tested.turnCancelMap[peer]; ok { + t.Errorf("expecting peer to be not present in turn cancel map, got present") + } + if _, ok := tested.relayCancelMap[peer]; ok { + t.Errorf("expecting peer to be not present in relay cancel map, got present") + } +} + +func validateMAC(t *testing.T, algo func() hash.Hash, username string, actualMAC string, key []byte) { + t.Helper() + mac := hmac.New(algo, key) + + _, err := mac.Write([]byte(username)) + if err != nil { + t.Fatal(err) + } + + expectedMAC := mac.Sum(nil) + decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC) + if err != nil { + t.Fatal(err) + } + equal := hmac.Equal(decodedMAC, expectedMAC) + + if !equal { + t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC) + } +} diff --git a/management/server/turncredentials.go b/management/server/turncredentials.go deleted file mode 100644 index 79f42e882..000000000 --- a/management/server/turncredentials.go +++ /dev/null @@ -1,126 +0,0 @@ -package server - -import ( - "context" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "fmt" - "sync" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/management/proto" -) - -// TURNCredentialsManager used to manage TURN credentials -type TURNCredentialsManager interface { - GenerateCredentials() TURNCredentials - SetupRefresh(ctx context.Context, peerKey string) - CancelRefresh(peerKey string) -} - -// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server -type TimeBasedAuthSecretsManager struct { - mux sync.Mutex - config *TURNConfig - updateManager *PeersUpdateManager - cancelMap map[string]chan struct{} -} - -type TURNCredentials struct { - Username string - Password string -} - -func NewTimeBasedAuthSecretsManager(updateManager *PeersUpdateManager, config *TURNConfig) *TimeBasedAuthSecretsManager { - return &TimeBasedAuthSecretsManager{ - mux: sync.Mutex{}, - config: config, - updateManager: updateManager, - cancelMap: make(map[string]chan struct{}), - } -} - -// GenerateCredentials generates new time-based secret credentials - basically username is a unix timestamp and password is a HMAC hash of a timestamp with a preshared TURN secret -func (m *TimeBasedAuthSecretsManager) GenerateCredentials() TURNCredentials { - mac := hmac.New(sha1.New, []byte(m.config.Secret)) - - timeAuth := time.Now().Add(m.config.CredentialsTTL.Duration).Unix() - - username := fmt.Sprint(timeAuth) - - _, err := mac.Write([]byte(username)) - if err != nil { - log.Errorln("Generating turn password failed with error: ", err) - } - - bytePassword := mac.Sum(nil) - password := base64.StdEncoding.EncodeToString(bytePassword) - - return TURNCredentials{ - Username: username, - Password: password, - } - -} - -func (m *TimeBasedAuthSecretsManager) cancel(peerID string) { - if channel, ok := m.cancelMap[peerID]; ok { - close(channel) - delete(m.cancelMap, peerID) - } -} - -// CancelRefresh cancels scheduled peer credentials refresh -func (m *TimeBasedAuthSecretsManager) CancelRefresh(peerID string) { - m.mux.Lock() - defer m.mux.Unlock() - m.cancel(peerID) -} - -// SetupRefresh starts peer credentials refresh. Since credentials are expiring (TTL) it is necessary to always generate them and send to the peer. -// A goroutine is created and put into TimeBasedAuthSecretsManager.cancelMap. This routine should be cancelled if peer is gone. -func (m *TimeBasedAuthSecretsManager) SetupRefresh(ctx context.Context, peerID string) { - m.mux.Lock() - defer m.mux.Unlock() - m.cancel(peerID) - cancel := make(chan struct{}, 1) - m.cancelMap[peerID] = cancel - log.WithContext(ctx).Debugf("starting turn refresh for %s", peerID) - - go func() { - // we don't want to regenerate credentials right on expiration, so we do it slightly before (at 3/4 of TTL) - ticker := time.NewTicker(m.config.CredentialsTTL.Duration / 4 * 3) - - for { - select { - case <-cancel: - log.WithContext(ctx).Debugf("stopping turn refresh for %s", peerID) - return - case <-ticker.C: - c := m.GenerateCredentials() - var turns []*proto.ProtectedHostConfig - for _, host := range m.config.Turns { - turns = append(turns, &proto.ProtectedHostConfig{ - HostConfig: &proto.HostConfig{ - Uri: host.URI, - Protocol: ToResponseProto(host.Proto), - }, - User: c.Username, - Password: c.Password, - }) - } - - update := &proto.SyncResponse{ - WiretrusteeConfig: &proto.WiretrusteeConfig{ - Turns: turns, - }, - } - log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) - } - } - }() -} diff --git a/management/server/turncredentials_test.go b/management/server/turncredentials_test.go deleted file mode 100644 index 667dccbb5..000000000 --- a/management/server/turncredentials_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package server - -import ( - "context" - "crypto/hmac" - "crypto/sha1" - "encoding/base64" - "testing" - "time" - - "github.com/netbirdio/netbird/util" -) - -var TurnTestHost = &Host{ - Proto: UDP, - URI: "turn:turn.wiretrustee.com:77777", - Username: "username", - Password: "", -} - -func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) { - ttl := util.Duration{Duration: time.Hour} - secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) - - tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ - CredentialsTTL: ttl, - Secret: secret, - Turns: []*Host{TurnTestHost}, - }) - - credentials := tested.GenerateCredentials() - - if credentials.Username == "" { - t.Errorf("expected generated TURN username not to be empty, got empty") - } - if credentials.Password == "" { - t.Errorf("expected generated TURN password not to be empty, got empty") - } - - validateMAC(t, credentials.Username, credentials.Password, []byte(secret)) - -} - -func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { - ttl := util.Duration{Duration: 2 * time.Second} - secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) - peer := "some_peer" - updateChannel := peersManager.CreateChannel(context.Background(), peer) - - tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ - CredentialsTTL: ttl, - Secret: secret, - Turns: []*Host{TurnTestHost}, - }) - - tested.SetupRefresh(context.Background(), peer) - - if _, ok := tested.cancelMap[peer]; !ok { - t.Errorf("expecting peer to be present in a cancel map, got not present") - } - - var updates []*UpdateMessage - -loop: - for timeout := time.After(5 * time.Second); ; { - - select { - case update := <-updateChannel: - updates = append(updates, update) - case <-timeout: - break loop - } - - if len(updates) >= 2 { - break loop - } - } - - if len(updates) < 2 { - t.Errorf("expecting 2 peer credentials updates, got %v", len(updates)) - } - - firstUpdate := updates[0].Update.GetWiretrusteeConfig().Turns[0] - secondUpdate := updates[1].Update.GetWiretrusteeConfig().Turns[0] - - if firstUpdate.Password == secondUpdate.Password { - t.Errorf("expecting first credential update password %v to be diffeerent from second, got equal", firstUpdate.Password) - } - -} - -func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) { - ttl := util.Duration{Duration: time.Hour} - secret := "some_secret" - peersManager := NewPeersUpdateManager(nil) - peer := "some_peer" - - tested := NewTimeBasedAuthSecretsManager(peersManager, &TURNConfig{ - CredentialsTTL: ttl, - Secret: secret, - Turns: []*Host{TurnTestHost}, - }) - - tested.SetupRefresh(context.Background(), peer) - if _, ok := tested.cancelMap[peer]; !ok { - t.Errorf("expecting peer to be present in a cancel map, got not present") - } - - tested.CancelRefresh(peer) - if _, ok := tested.cancelMap[peer]; ok { - t.Errorf("expecting peer to be not present in a cancel map, got present") - } -} - -func validateMAC(t *testing.T, username string, actualMAC string, key []byte) { - t.Helper() - mac := hmac.New(sha1.New, key) - - _, err := mac.Write([]byte(username)) - if err != nil { - t.Fatal(err) - } - - expectedMAC := mac.Sum(nil) - decodedMAC, err := base64.StdEncoding.DecodeString(actualMAC) - if err != nil { - t.Fatal(err) - } - equal := hmac.Equal(decodedMAC, expectedMAC) - - if !equal { - t.Errorf("expected password MAC to be %s. got %s", expectedMAC, decodedMAC) - } -} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 68e694893..36168cffe 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -84,7 +84,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) default: dropped = true - log.WithContext(ctx).Warnf("channel for peer %s is %d full", peerID, len(channel)) + log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel)) } } else { log.WithContext(ctx).Debugf("peer %s has no channel", peerID) diff --git a/management/server/user.go b/management/server/user.go index 6c7fdfe3c..7acb0b487 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -90,15 +90,16 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() } -func (u *User) updateLastLogin(login time.Time) { - u.LastLogin = login -} - // HasAdminPower returns true if the user has admin or owner roles, false otherwise func (u *User) HasAdminPower() bool { return u.Role == UserRoleAdmin || u.Role == UserRoleOwner } +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + // ToUserInfo converts a User object to a UserInfo object. func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { autoGroups := u.AutoGroups @@ -362,39 +363,35 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { + return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +} + // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccount(ctx, account.Id) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return nil, fmt.Errorf("failed to get an account from store %v", err) + return nil, err } - user, ok := account.Users[claims.UserId] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") - } - - // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC + // this code should be outside of the am.GetAccountIDFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. newLogin := user.LastDashboardLoginChanged(claims.LastLogin) - err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, accountID, userID, claims.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } if newLogin { meta := map[string]any{"timestamp": claims.LastLogin} - am.StoreEvent(ctx, claims.UserId, claims.UserId, account.Id, activity.DashboardLogin, meta) + am.StoreEvent(ctx, claims.UserId, claims.UserId, accountID, activity.DashboardLogin, meta) } return user, nil @@ -654,63 +651,48 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string // GetPAT returns a specific PAT from a user func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found: %s", err) + return nil, err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { - return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this userser") + for _, pat := range targetUser.PATsG { + if pat.ID == tokenID { + return pat.Copy(), nil + } } - pat := targetUser.PATs[tokenID] - if pat == nil { - return nil, status.Errorf(status.NotFound, "PAT not found") - } - - return pat, nil + return nil, status.Errorf(status.NotFound, "PAT not found") } // GetAllPATs returns all PATs for a user func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found: %s", err) + return nil, err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") - } - - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { + if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - var pats []*PersonalAccessToken - for _, pat := range targetUser.PATs { - pats = append(pats, pat) + pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) + for _, pat := range targetUser.PATsG { + pats = append(pats, pat.Copy()) } return pats, nil diff --git a/management/server/user_test.go b/management/server/user_test.go index f5a26ef89..7740b4059 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -202,7 +202,8 @@ func TestUser_GetPAT(t *testing.T) { defer store.Close(context.Background()) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ - Id: mockUserID, + Id: mockUserID, + AccountID: mockAccountID, PATs: map[string]*PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, @@ -234,7 +235,8 @@ func TestUser_GetAllPATs(t *testing.T) { defer store.Close(context.Background()) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockUserID] = &User{ - Id: mockUserID, + Id: mockUserID, + AccountID: mockAccountID, PATs: map[string]*PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, @@ -799,7 +801,10 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "") + accID, err := am.GetAccountIDByUserOrAccountID(context.Background(), "", account.Id, "") + assert.NoError(t, err) + + acc, err := am.Store.GetAccount(context.Background(), accID) assert.NoError(t, err) for _, id := range tc.expectedDeleted { diff --git a/relay/Dockerfile b/relay/Dockerfile new file mode 100644 index 000000000..f750027c3 --- /dev/null +++ b/relay/Dockerfile @@ -0,0 +1,4 @@ +FROM gcr.io/distroless/base:debug +ENTRYPOINT [ "/go/bin/netbird-relay" ] +ENV NB_LOG_FILE=console +COPY netbird-relay /go/bin/netbird-relay diff --git a/relay/auth/allow/allow_all.go b/relay/auth/allow/allow_all.go new file mode 100644 index 000000000..2d30c59c9 --- /dev/null +++ b/relay/auth/allow/allow_all.go @@ -0,0 +1,14 @@ +package allow + +// Auth is a Validator that allows all connections. +// Used this for testing purposes only. +type Auth struct { +} + +func (a *Auth) Validate(any) error { + return nil +} + +func (a *Auth) ValidateHelloMsgType(any) error { + return nil +} diff --git a/relay/auth/doc.go b/relay/auth/doc.go new file mode 100644 index 000000000..b3e8dbb08 --- /dev/null +++ b/relay/auth/doc.go @@ -0,0 +1,26 @@ +/* +Package auth manages the authentication process with the relay server. + +Key Components: + +Validator: The Validator interface defines the Validate method. Any type that provides this method can be used as a +Validator. + +Methods: + +Validate(func() hash.Hash, any): This method is defined in the Validator interface and is used to validate the authentication. + +Usage: + +To create a new AllowAllAuth validator, simply instantiate it: + + validator := &allow.Auth{} + +To validate the authentication, use the Validate method: + + err := validator.Validate(sha256.New, any) + +This package provides a simple and effective way to manage authentication with the relay server, ensuring that the +peers are authenticated properly. +*/ +package auth diff --git a/relay/auth/hmac/doc.go b/relay/auth/hmac/doc.go new file mode 100644 index 000000000..a1b135aa6 --- /dev/null +++ b/relay/auth/hmac/doc.go @@ -0,0 +1,8 @@ +/* +This package uses a similar HMAC method for authentication with the TURN server. The Management server provides the +tokens for the peers. The peers manage these tokens in the token store. The token store is a simple thread safe store +that keeps the tokens in memory. These tokens are used to authenticate the peers with the Relay server in the hello +message. +*/ + +package hmac diff --git a/relay/auth/hmac/store.go b/relay/auth/hmac/store.go new file mode 100644 index 000000000..169b8d6b0 --- /dev/null +++ b/relay/auth/hmac/store.go @@ -0,0 +1,44 @@ +package hmac + +import ( + "encoding/base64" + "fmt" + "sync" + + v2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" +) + +// TokenStore is a simple in-memory store for token +// With this can update the token in thread safe way +type TokenStore struct { + mu sync.Mutex + token []byte +} + +func (a *TokenStore) UpdateToken(token *Token) error { + a.mu.Lock() + defer a.mu.Unlock() + if token == nil { + return nil + } + + sig, err := base64.StdEncoding.DecodeString(token.Signature) + if err != nil { + return fmt.Errorf("decode signature: %w", err) + } + + tok := v2.Token{ + AuthAlgo: v2.AuthAlgoHMACSHA256, + Signature: sig, + Payload: []byte(token.Payload), + } + + a.token = tok.Marshal() + return nil +} + +func (a *TokenStore) TokenBinary() []byte { + a.mu.Lock() + defer a.mu.Unlock() + return a.token +} diff --git a/relay/auth/hmac/token.go b/relay/auth/hmac/token.go new file mode 100644 index 000000000..581b1d6fd --- /dev/null +++ b/relay/auth/hmac/token.go @@ -0,0 +1,94 @@ +package hmac + +import ( + "bytes" + "crypto/hmac" + "encoding/base64" + "encoding/gob" + "fmt" + "hash" + "strconv" + "time" + + log "github.com/sirupsen/logrus" +) + +type Token struct { + Payload string + Signature string +} + +func unmarshalToken(payload []byte) (Token, error) { + var creds Token + buffer := bytes.NewBuffer(payload) + decoder := gob.NewDecoder(buffer) + err := decoder.Decode(&creds) + return creds, err +} + +// TimedHMAC generates a token with TTL and uses a pre-shared secret known to the relay server +type TimedHMAC struct { + secret string + timeToLive time.Duration +} + +// NewTimedHMAC creates a new TimedHMAC instance +func NewTimedHMAC(secret string, timeToLive time.Duration) *TimedHMAC { + return &TimedHMAC{ + secret: secret, + timeToLive: timeToLive, + } +} + +// GenerateToken generates new time-based secret token - basically Payload is a unix timestamp and Signature is a HMAC +// hash of a timestamp with a preshared TURN secret +func (m *TimedHMAC) GenerateToken(algo func() hash.Hash) (*Token, error) { + timeAuth := time.Now().Add(m.timeToLive).Unix() + timeStamp := strconv.FormatInt(timeAuth, 10) + + checksum, err := m.generate(algo, timeStamp) + if err != nil { + return nil, err + } + + return &Token{ + Payload: timeStamp, + Signature: base64.StdEncoding.EncodeToString(checksum), + }, nil +} + +// Validate checks if the token is valid +func (m *TimedHMAC) Validate(algo func() hash.Hash, token Token) error { + expectedMAC, err := m.generate(algo, token.Payload) + if err != nil { + return err + } + + expectedSignature := base64.StdEncoding.EncodeToString(expectedMAC) + + if !hmac.Equal([]byte(expectedSignature), []byte(token.Signature)) { + return fmt.Errorf("signature mismatch") + } + + timeAuthInt, err := strconv.ParseInt(token.Payload, 10, 64) + if err != nil { + return fmt.Errorf("invalid payload: %w", err) + } + + if time.Now().Unix() > timeAuthInt { + return fmt.Errorf("expired token") + } + + return nil +} + +func (m *TimedHMAC) generate(algo func() hash.Hash, payload string) ([]byte, error) { + mac := hmac.New(algo, []byte(m.secret)) + _, err := mac.Write([]byte(payload)) + if err != nil { + log.Debugf("failed to generate token: %s", err) + return nil, fmt.Errorf("failed to generate token: %w", err) + } + + return mac.Sum(nil), nil +} diff --git a/relay/auth/hmac/token_test.go b/relay/auth/hmac/token_test.go new file mode 100644 index 000000000..e629eab97 --- /dev/null +++ b/relay/auth/hmac/token_test.go @@ -0,0 +1,105 @@ +package hmac + +import ( + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "strconv" + "testing" + "time" +) + +func TestGenerateCredentials(t *testing.T) { + secret := "secret" + timeToLive := 1 * time.Hour + v := NewTimedHMAC(secret, timeToLive) + + creds, err := v.GenerateToken(sha1.New) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if creds.Payload == "" { + t.Fatalf("expected non-empty payload") + } + + _, err = strconv.ParseInt(creds.Payload, 10, 64) + if err != nil { + t.Fatalf("expected payload to be a valid unix timestamp, got %v", err) + } + + _, err = base64.StdEncoding.DecodeString(creds.Signature) + if err != nil { + t.Fatalf("expected signature to be base64 encoded, got %v", err) + } +} + +func TestValidateCredentials(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + manager := NewTimedHMAC(secret, timeToLive) + + // Test valid token + creds, err := manager.GenerateToken(sha1.New) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if err := manager.Validate(sha1.New, *creds); err != nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestInvalidSignature(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + manager := NewTimedHMAC(secret, timeToLive) + + creds, err := manager.GenerateToken(sha256.New) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + invalidCreds := &Token{ + Payload: creds.Payload, + Signature: "invalidsignature", + } + + if err = manager.Validate(sha1.New, *invalidCreds); err == nil { + t.Fatalf("expected invalid token due to signature mismatch") + } +} + +func TestExpired(t *testing.T) { + secret := "supersecret" + v := NewTimedHMAC(secret, -1*time.Hour) + expiredCreds, err := v.GenerateToken(sha256.New) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if err = v.Validate(sha1.New, *expiredCreds); err == nil { + t.Fatalf("expected invalid token due to expiration") + } +} + +func TestInvalidPayload(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + v := NewTimedHMAC(secret, timeToLive) + + creds, err := v.GenerateToken(sha256.New) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + // Test invalid payload + invalidPayloadCreds := &Token{ + Payload: "invalidtimestamp", + Signature: creds.Signature, + } + + if err = v.Validate(sha1.New, *invalidPayloadCreds); err == nil { + t.Fatalf("expected invalid token due to invalid payload") + } +} diff --git a/relay/auth/hmac/v2/algo.go b/relay/auth/hmac/v2/algo.go new file mode 100644 index 000000000..c379c2bd7 --- /dev/null +++ b/relay/auth/hmac/v2/algo.go @@ -0,0 +1,40 @@ +package v2 + +import ( + "crypto/sha256" + "hash" +) + +const ( + AuthAlgoUnknown AuthAlgo = iota + AuthAlgoHMACSHA256 +) + +type AuthAlgo uint8 + +func (a AuthAlgo) String() string { + switch a { + case AuthAlgoHMACSHA256: + return "HMAC-SHA256" + default: + return "Unknown" + } +} + +func (a AuthAlgo) New() func() hash.Hash { + switch a { + case AuthAlgoHMACSHA256: + return sha256.New + default: + return nil + } +} + +func (a AuthAlgo) Size() int { + switch a { + case AuthAlgoHMACSHA256: + return sha256.Size + default: + return 0 + } +} diff --git a/relay/auth/hmac/v2/generator.go b/relay/auth/hmac/v2/generator.go new file mode 100644 index 000000000..827532730 --- /dev/null +++ b/relay/auth/hmac/v2/generator.go @@ -0,0 +1,45 @@ +package v2 + +import ( + "crypto/hmac" + "fmt" + "hash" + "strconv" + "time" +) + +type Generator struct { + algo func() hash.Hash + algoType AuthAlgo + secret []byte + timeToLive time.Duration +} + +func NewGenerator(algo AuthAlgo, secret []byte, timeToLive time.Duration) (*Generator, error) { + algoFunc := algo.New() + if algoFunc == nil { + return nil, fmt.Errorf("unsupported auth algorithm: %s", algo) + } + return &Generator{ + algo: algoFunc, + algoType: algo, + secret: secret, + timeToLive: timeToLive, + }, nil +} + +func (g *Generator) GenerateToken() (*Token, error) { + expirationTime := time.Now().Add(g.timeToLive).Unix() + + payload := []byte(strconv.FormatInt(expirationTime, 10)) + + h := hmac.New(g.algo, g.secret) + h.Write(payload) + signature := h.Sum(nil) + + return &Token{ + AuthAlgo: g.algoType, + Signature: signature, + Payload: payload, + }, nil +} diff --git a/relay/auth/hmac/v2/hmac_test.go b/relay/auth/hmac/v2/hmac_test.go new file mode 100644 index 000000000..40336363f --- /dev/null +++ b/relay/auth/hmac/v2/hmac_test.go @@ -0,0 +1,110 @@ +package v2 + +import ( + "strconv" + "testing" + "time" +) + +func TestGenerateCredentials(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + if len(token.Payload) == 0 { + t.Fatalf("expected non-empty payload") + } + + _, err = strconv.ParseInt(string(token.Payload), 10, 64) + if err != nil { + t.Fatalf("expected payload to be a valid unix timestamp, got %v", err) + } +} + +func TestValidateCredentials(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err != nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestInvalidSignature(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + token.Signature = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err == nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestExpired(t *testing.T) { + secret := "supersecret" + timeToLive := -1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err == nil { + t.Fatalf("expected valid token: %s", err) + } +} + +func TestInvalidPayload(t *testing.T) { + secret := "supersecret" + timeToLive := 1 * time.Hour + g, err := NewGenerator(AuthAlgoHMACSHA256, []byte(secret), timeToLive) + if err != nil { + t.Fatalf("failed to create generator: %v", err) + } + + token, err := g.GenerateToken() + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + token.Payload = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + + v := NewValidator([]byte(secret)) + if err := v.Validate(token.Marshal()); err == nil { + t.Fatalf("expected invalid token due to invalid payload") + } +} diff --git a/relay/auth/hmac/v2/token.go b/relay/auth/hmac/v2/token.go new file mode 100644 index 000000000..553ac01b9 --- /dev/null +++ b/relay/auth/hmac/v2/token.go @@ -0,0 +1,39 @@ +package v2 + +import "errors" + +type Token struct { + AuthAlgo AuthAlgo + Signature []byte + Payload []byte +} + +func (t *Token) Marshal() []byte { + size := 1 + len(t.Signature) + len(t.Payload) + + buf := make([]byte, size) + + buf[0] = byte(t.AuthAlgo) + copy(buf[1:], t.Signature) + copy(buf[1+len(t.Signature):], t.Payload) + + return buf +} + +func UnmarshalToken(data []byte) (*Token, error) { + if len(data) == 0 { + return nil, errors.New("invalid token data") + } + + algo := AuthAlgo(data[0]) + sigSize := algo.Size() + if len(data) < 1+sigSize { + return nil, errors.New("invalid token data: insufficient length") + } + + return &Token{ + AuthAlgo: algo, + Signature: data[1 : 1+sigSize], + Payload: data[1+sigSize:], + }, nil +} diff --git a/relay/auth/hmac/v2/validator.go b/relay/auth/hmac/v2/validator.go new file mode 100644 index 000000000..7f448dd5f --- /dev/null +++ b/relay/auth/hmac/v2/validator.go @@ -0,0 +1,59 @@ +package v2 + +import ( + "crypto/hmac" + "errors" + "fmt" + "strconv" + "time" +) + +const minLengthUnixTimestamp = 10 + +type Validator struct { + secret []byte +} + +func NewValidator(secret []byte) *Validator { + return &Validator{secret: secret} +} + +func (v *Validator) Validate(data any) error { + d, ok := data.([]byte) + if !ok { + return fmt.Errorf("invalid data type") + } + + token, err := UnmarshalToken(d) + if err != nil { + return fmt.Errorf("unmarshal token: %w", err) + } + + if len(token.Payload) < minLengthUnixTimestamp { + return errors.New("invalid payload: insufficient length") + } + + hashFunc := token.AuthAlgo.New() + if hashFunc == nil { + return fmt.Errorf("unsupported auth algorithm: %s", token.AuthAlgo) + } + + h := hmac.New(hashFunc, v.secret) + h.Write(token.Payload) + expectedMAC := h.Sum(nil) + + if !hmac.Equal(token.Signature, expectedMAC) { + return errors.New("invalid signature") + } + + timestamp, err := strconv.ParseInt(string(token.Payload), 10, 64) + if err != nil { + return fmt.Errorf("invalid payload: %w", err) + } + + if time.Now().Unix() > timestamp { + return fmt.Errorf("expired token") + } + + return nil +} diff --git a/relay/auth/hmac/validator.go b/relay/auth/hmac/validator.go new file mode 100644 index 000000000..b0b7542be --- /dev/null +++ b/relay/auth/hmac/validator.go @@ -0,0 +1,33 @@ +package hmac + +import ( + "crypto/sha256" + "fmt" + "time" + + log "github.com/sirupsen/logrus" +) + +type TimedHMACValidator struct { + *TimedHMAC +} + +func NewTimedHMACValidator(secret string, duration time.Duration) *TimedHMACValidator { + ta := NewTimedHMAC(secret, duration) + return &TimedHMACValidator{ + ta, + } +} + +func (a *TimedHMACValidator) Validate(credentials any) error { + b, ok := credentials.([]byte) + if !ok { + return fmt.Errorf("invalid credentials type") + } + c, err := unmarshalToken(b) + if err != nil { + log.Debugf("failed to unmarshal token: %s", err) + return err + } + return a.TimedHMAC.Validate(sha256.New, c) +} diff --git a/relay/auth/validator.go b/relay/auth/validator.go new file mode 100644 index 000000000..854efd5bb --- /dev/null +++ b/relay/auth/validator.go @@ -0,0 +1,35 @@ +package auth + +import ( + "time" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" + authv2 "github.com/netbirdio/netbird/relay/auth/hmac/v2" +) + +// Validator is an interface that defines the Validate method. +type Validator interface { + Validate(any) error + // Deprecated: Use Validate instead. + ValidateHelloMsgType(any) error +} + +type TimedHMACValidator struct { + authenticatorV2 *authv2.Validator + authenticator *auth.TimedHMACValidator +} + +func NewTimedHMACValidator(secret []byte, duration time.Duration) *TimedHMACValidator { + return &TimedHMACValidator{ + authenticatorV2: authv2.NewValidator(secret), + authenticator: auth.NewTimedHMACValidator(string(secret), duration), + } +} + +func (a *TimedHMACValidator) Validate(credentials any) error { + return a.authenticatorV2.Validate(credentials) +} + +func (a *TimedHMACValidator) ValidateHelloMsgType(credentials any) error { + return a.authenticator.Validate(credentials) +} diff --git a/relay/client/addr.go b/relay/client/addr.go new file mode 100644 index 000000000..af4f459f8 --- /dev/null +++ b/relay/client/addr.go @@ -0,0 +1,13 @@ +package client + +type RelayAddr struct { + addr string +} + +func (a RelayAddr) Network() string { + return "relay" +} + +func (a RelayAddr) String() string { + return a.addr +} diff --git a/relay/client/client.go b/relay/client/client.go new file mode 100644 index 000000000..90bc3ac41 --- /dev/null +++ b/relay/client/client.go @@ -0,0 +1,574 @@ +package client + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/client/dialer/ws" + "github.com/netbirdio/netbird/relay/healthcheck" + "github.com/netbirdio/netbird/relay/messages" +) + +const ( + bufferSize = 8820 + serverResponseTimeout = 8 * time.Second +) + +var ( + ErrConnAlreadyExists = fmt.Errorf("connection already exists") +) + +type internalStopFlag struct { + sync.Mutex + stop bool +} + +func newInternalStopFlag() *internalStopFlag { + return &internalStopFlag{} +} + +func (isf *internalStopFlag) set() { + isf.Lock() + defer isf.Unlock() + isf.stop = true +} + +func (isf *internalStopFlag) isSet() bool { + isf.Lock() + defer isf.Unlock() + return isf.stop +} + +// Msg carry the payload from the server to the client. With this struct, the net.Conn can free the buffer. +type Msg struct { + Payload []byte + + bufPool *sync.Pool + bufPtr *[]byte +} + +func (m *Msg) Free() { + m.bufPool.Put(m.bufPtr) +} + +// connContainer is a container for the connection to the peer. It is responsible for managing the messages from the +// server and forwarding them to the upper layer content reader. +type connContainer struct { + log *log.Entry + conn *Conn + messages chan Msg + msgChanLock sync.Mutex + closed bool // flag to check if channel is closed + ctx context.Context + cancel context.CancelFunc +} + +func newConnContainer(log *log.Entry, conn *Conn, messages chan Msg) *connContainer { + ctx, cancel := context.WithCancel(context.Background()) + return &connContainer{ + log: log, + conn: conn, + messages: messages, + ctx: ctx, + cancel: cancel, + } +} + +func (cc *connContainer) writeMsg(msg Msg) { + cc.msgChanLock.Lock() + defer cc.msgChanLock.Unlock() + + if cc.closed { + msg.Free() + return + } + + select { + case cc.messages <- msg: + case <-cc.ctx.Done(): + msg.Free() + default: + msg.Free() + cc.log.Infof("message queue is full") + // todo consider to close the connection + } +} + +func (cc *connContainer) close() { + cc.cancel() + + cc.msgChanLock.Lock() + defer cc.msgChanLock.Unlock() + + if cc.closed { + return + } + + cc.closed = true + close(cc.messages) + + for msg := range cc.messages { + msg.Free() + } +} + +// Client is a client for the relay server. It is responsible for establishing a connection to the relay server and +// managing connections to other peers. All exported functions are safe to call concurrently. After close the connection, +// the client can be reused by calling Connect again. When the client is closed, all connections are closed too. +// While the Connect is in progress, the OpenConn function will block until the connection is established with relay server. +type Client struct { + log *log.Entry + parentCtx context.Context + connectionURL string + authTokenStore *auth.TokenStore + hashedID []byte + + bufPool *sync.Pool + + relayConn net.Conn + conns map[string]*connContainer + serviceIsRunning bool + mu sync.Mutex // protect serviceIsRunning and conns + readLoopMutex sync.Mutex + wgReadLoop sync.WaitGroup + instanceURL *RelayAddr + muInstanceURL sync.Mutex + + onDisconnectListener func() + listenerMutex sync.Mutex +} + +// NewClient creates a new client for the relay server. The client is not connected to the server until the Connect +func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { + hashedID, hashedStringId := messages.HashID(peerID) + c := &Client{ + log: log.WithFields(log.Fields{"relay": serverURL}), + parentCtx: ctx, + connectionURL: serverURL, + authTokenStore: authTokenStore, + hashedID: hashedID, + bufPool: &sync.Pool{ + New: func() any { + buf := make([]byte, bufferSize) + return &buf + }, + }, + conns: make(map[string]*connContainer), + } + c.log.Infof("create new relay connection: local peerID: %s, local peer hashedID: %s", peerID, hashedStringId) + return c +} + +// Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. +func (c *Client) Connect() error { + c.log.Infof("connecting to relay server") + c.readLoopMutex.Lock() + defer c.readLoopMutex.Unlock() + + c.mu.Lock() + defer c.mu.Unlock() + + if c.serviceIsRunning { + return nil + } + + err := c.connect() + if err != nil { + return err + } + + c.serviceIsRunning = true + + c.wgReadLoop.Add(1) + go c.readLoop(c.relayConn) + + c.log.Infof("relay connection established") + return nil +} + +// OpenConn create a new net.Conn for the destination peer ID. In case if the connection is in progress +// to the relay server, the function will block until the connection is established or timed out. Otherwise, +// it will return immediately. +// todo: what should happen if call with the same peerID with multiple times? +func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.serviceIsRunning { + return nil, fmt.Errorf("relay connection is not established") + } + + hashedID, hashedStringID := messages.HashID(dstPeerID) + _, ok := c.conns[hashedStringID] + if ok { + return nil, ErrConnAlreadyExists + } + + c.log.Infof("open connection to peer: %s", hashedStringID) + msgChannel := make(chan Msg, 100) + conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) + + c.conns[hashedStringID] = newConnContainer(c.log, conn, msgChannel) + return conn, nil +} + +// ServerInstanceURL returns the address of the relay server. It could change after the close and reopen the connection. +func (c *Client) ServerInstanceURL() (string, error) { + c.muInstanceURL.Lock() + defer c.muInstanceURL.Unlock() + if c.instanceURL == nil { + return "", fmt.Errorf("relay connection is not established") + } + return c.instanceURL.String(), nil +} + +// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. +func (c *Client) SetOnDisconnectListener(fn func()) { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + c.onDisconnectListener = fn +} + +// HasConns returns true if there are connections. +func (c *Client) HasConns() bool { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.conns) > 0 +} + +// Close closes the connection to the relay server and all connections to other peers. +func (c *Client) Close() error { + return c.close(true) +} + +func (c *Client) connect() error { + conn, err := ws.Dial(c.connectionURL) + if err != nil { + return err + } + c.relayConn = conn + + err = c.handShake() + if err != nil { + cErr := conn.Close() + if cErr != nil { + c.log.Errorf("failed to close connection: %s", cErr) + } + return err + } + + return nil +} + +func (c *Client) handShake() error { + msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) + if err != nil { + c.log.Errorf("failed to marshal auth message: %s", err) + return err + } + + _, err = c.relayConn.Write(msg) + if err != nil { + c.log.Errorf("failed to send auth message: %s", err) + return err + } + buf := make([]byte, messages.MaxHandshakeRespSize) + n, err := c.readWithTimeout(buf) + if err != nil { + c.log.Errorf("failed to read auth response: %s", err) + return err + } + + _, err = messages.ValidateVersion(buf[:n]) + if err != nil { + return fmt.Errorf("validate version: %w", err) + } + + msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + c.log.Errorf("failed to determine message type: %s", err) + return err + } + + if msgType != messages.MsgTypeAuthResponse { + c.log.Errorf("unexpected message type: %s", msgType) + return fmt.Errorf("unexpected message type") + } + + addr, err := messages.UnmarshalAuthResponse(buf[messages.SizeOfProtoHeader:n]) + if err != nil { + return err + } + + c.muInstanceURL.Lock() + c.instanceURL = &RelayAddr{addr: addr} + c.muInstanceURL.Unlock() + return nil +} + +func (c *Client) readLoop(relayConn net.Conn) { + internallyStoppedFlag := newInternalStopFlag() + hc := healthcheck.NewReceiver(c.log) + go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag) + + var ( + errExit error + n int + ) + for { + bufPtr := c.bufPool.Get().(*[]byte) + buf := *bufPtr + n, errExit = relayConn.Read(buf) + if errExit != nil { + c.log.Infof("start to Relay read loop exit") + c.mu.Lock() + if c.serviceIsRunning && !internallyStoppedFlag.isSet() { + c.log.Debugf("failed to read message from relay server: %s", errExit) + } + c.mu.Unlock() + break + } + + _, err := messages.ValidateVersion(buf[:n]) + if err != nil { + c.log.Errorf("failed to validate protocol version: %s", err) + c.bufPool.Put(bufPtr) + continue + } + + msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + c.log.Errorf("failed to determine message type: %s", err) + c.bufPool.Put(bufPtr) + continue + } + + if !c.handleMsg(msgType, buf[messages.SizeOfProtoHeader:n], bufPtr, hc, internallyStoppedFlag) { + break + } + } + + hc.Stop() + + c.muInstanceURL.Lock() + c.instanceURL = nil + c.muInstanceURL.Unlock() + + c.notifyDisconnected() + c.wgReadLoop.Done() + _ = c.close(false) +} + +func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) (continueLoop bool) { + switch msgType { + case messages.MsgTypeHealthCheck: + c.handleHealthCheck(hc, internallyStoppedFlag) + c.bufPool.Put(bufPtr) + case messages.MsgTypeTransport: + return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) + case messages.MsgTypeClose: + c.log.Debugf("relay connection close by server") + c.bufPool.Put(bufPtr) + return false + } + + return true +} + +func (c *Client) handleHealthCheck(hc *healthcheck.Receiver, internallyStoppedFlag *internalStopFlag) { + msg := messages.MarshalHealthcheck() + _, wErr := c.relayConn.Write(msg) + if wErr != nil { + if c.serviceIsRunning && !internallyStoppedFlag.isSet() { + c.log.Errorf("failed to send heartbeat: %s", wErr) + } + } + hc.Heartbeat() +} + +func (c *Client) handleTransportMsg(buf []byte, bufPtr *[]byte, internallyStoppedFlag *internalStopFlag) bool { + peerID, payload, err := messages.UnmarshalTransportMsg(buf) + if err != nil { + if c.serviceIsRunning && !internallyStoppedFlag.isSet() { + c.log.Errorf("failed to parse transport message: %v", err) + } + + c.bufPool.Put(bufPtr) + return true + } + + stringID := messages.HashIDToString(peerID) + + c.mu.Lock() + if !c.serviceIsRunning { + c.mu.Unlock() + c.bufPool.Put(bufPtr) + return false + } + container, ok := c.conns[stringID] + c.mu.Unlock() + if !ok { + c.log.Errorf("peer not found: %s", stringID) + c.bufPool.Put(bufPtr) + return true + } + msg := Msg{ + bufPool: c.bufPool, + bufPtr: bufPtr, + Payload: payload, + } + container.writeMsg(msg) + return true +} + +func (c *Client) writeTo(connReference *Conn, id string, dstID []byte, payload []byte) (int, error) { + c.mu.Lock() + conn, ok := c.conns[id] + c.mu.Unlock() + if !ok { + return 0, io.EOF + } + + if conn.conn != connReference { + return 0, io.EOF + } + + // todo: use buffer pool instead of create new transport msg. + msg, err := messages.MarshalTransportMsg(dstID, payload) + if err != nil { + c.log.Errorf("failed to marshal transport message: %s", err) + return 0, err + } + + // the write always return with 0 length because the underling does not support the size feedback. + _, err = c.relayConn.Write(msg) + if err != nil { + c.log.Errorf("failed to write transport message: %s", err) + } + return len(payload), err +} + +func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, internalStopFlag *internalStopFlag) { + for { + select { + case _, ok := <-hc.OnTimeout: + if !ok { + return + } + c.log.Errorf("health check timeout") + internalStopFlag.set() + if err := conn.Close(); err != nil { + // ignore the err handling because the readLoop will handle it + c.log.Warnf("failed to close connection: %s", err) + } + return + case <-c.parentCtx.Done(): + err := c.close(true) + if err != nil { + c.log.Errorf("failed to teardown connection: %s", err) + } + return + } + } +} + +func (c *Client) closeAllConns() { + for _, container := range c.conns { + container.close() + } + c.conns = make(map[string]*connContainer) +} + +func (c *Client) closeConn(connReference *Conn, id string) error { + c.mu.Lock() + defer c.mu.Unlock() + + container, ok := c.conns[id] + if !ok { + return fmt.Errorf("connection already closed") + } + + if container.conn != connReference { + return fmt.Errorf("conn reference mismatch") + } + c.log.Infof("free up connection to peer: %s", id) + delete(c.conns, id) + container.close() + + return nil +} + +func (c *Client) close(gracefullyExit bool) error { + c.readLoopMutex.Lock() + defer c.readLoopMutex.Unlock() + + c.mu.Lock() + var err error + if !c.serviceIsRunning { + c.mu.Unlock() + c.log.Warn("relay connection was already marked as not running") + return nil + } + + c.serviceIsRunning = false + c.log.Infof("closing all peer connections") + c.closeAllConns() + if gracefullyExit { + c.writeCloseMsg() + } + err = c.relayConn.Close() + c.mu.Unlock() + + c.log.Infof("waiting for read loop to close") + c.wgReadLoop.Wait() + c.log.Infof("relay connection closed") + return err +} + +func (c *Client) notifyDisconnected() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onDisconnectListener == nil { + return + } + go c.onDisconnectListener() +} + +func (c *Client) writeCloseMsg() { + msg := messages.MarshalCloseMsg() + _, err := c.relayConn.Write(msg) + if err != nil { + c.log.Errorf("failed to send close message: %s", err) + } +} + +func (c *Client) readWithTimeout(buf []byte) (int, error) { + ctx, cancel := context.WithTimeout(c.parentCtx, serverResponseTimeout) + defer cancel() + + readDone := make(chan struct{}) + var ( + n int + err error + ) + + go func() { + n, err = c.relayConn.Read(buf) + close(readDone) + }() + + select { + case <-ctx.Done(): + return 0, fmt.Errorf("read operation timed out") + case <-readDone: + return n, err + } +} diff --git a/relay/client/client_test.go b/relay/client/client_test.go new file mode 100644 index 000000000..ef28203e9 --- /dev/null +++ b/relay/client/client_test.go @@ -0,0 +1,712 @@ +package client + +import ( + "context" + "net" + "os" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/relay/auth/allow" + "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/util" + + "github.com/netbirdio/netbird/relay/server" +) + +var ( + av = &allow.Auth{} + hmacTokenStore = &hmac.TokenStore{} + serverListenAddr = "127.0.0.1:1234" + serverURL = "rel://127.0.0.1:1234" +) + +func TestMain(m *testing.M) { + _ = util.InitLog("error", "console") + code := m.Run() + os.Exit(code) +} + +func TestClient(t *testing.T) { + ctx := context.Background() + + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + listenCfg := server.ListenerConfig{Address: serverListenAddr} + err := srv.Listen(listenCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for server to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + t.Log("alice connecting to server") + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientAlice.Close() + + t.Log("placeholder connecting to server") + clientPlaceHolder := NewClient(ctx, serverURL, hmacTokenStore, "clientPlaceHolder") + err = clientPlaceHolder.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientPlaceHolder.Close() + + t.Log("Bob connecting to server") + clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer clientBob.Close() + + t.Log("Alice open connection to Bob") + connAliceToBob, err := clientAlice.OpenConn("bob") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + t.Log("Bob open connection to Alice") + connBobToAlice, err := clientBob.OpenConn("alice") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + log.Debugf("alice sent message to bob") + + buf := make([]byte, 65535) + n, err := connBobToAlice.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + log.Debugf("on new message from alice to bob") + + if payload != string(buf[:n]) { + t.Fatalf("expected %s, got %s", payload, string(buf[:n])) + } +} + +func TestRegistration(t *testing.T) { + ctx := context.Background() + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + // wait for server to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + _ = srv.Shutdown(ctx) + t.Fatalf("failed to connect to server: %s", err) + } + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close conn: %s", err) + } + err = srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } +} + +func TestRegistrationTimeout(t *testing.T) { + ctx := context.Background() + fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{ + Port: 1234, + IP: net.ParseIP("0.0.0.0"), + }) + if err != nil { + t.Fatalf("failed to bind UDP server: %s", err) + } + defer func(fakeUDPListener *net.UDPConn) { + _ = fakeUDPListener.Close() + }(fakeUDPListener) + + fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{ + Port: 1234, + IP: net.ParseIP("0.0.0.0"), + }) + if err != nil { + t.Fatalf("failed to bind TCP server: %s", err) + } + defer func(fakeTCPListener *net.TCPListener) { + _ = fakeTCPListener.Close() + }(fakeTCPListener) + + clientAlice := NewClient(ctx, "127.0.0.1:1234", hmacTokenStore, "alice") + err = clientAlice.Connect() + if err == nil { + t.Errorf("failed to connect to server: %s", err) + } + log.Debugf("%s", err) + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close conn: %s", err) + } +} + +func TestEcho(t *testing.T) { + ctx := context.Background() + idAlice := "alice" + idBob := "bob" + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientAlice.Close() + if err != nil { + t.Errorf("failed to close Alice client: %s", err) + } + }() + + clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientBob.Close() + if err != nil { + t.Errorf("failed to close Bob client: %s", err) + } + }() + + connAliceToBob, err := clientAlice.OpenConn(idBob) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.OpenConn(idAlice) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + buf := make([]byte, 65535) + n, err := connBobToAlice.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + _, err = connBobToAlice.Write(buf[:n]) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + n, err = connAliceToBob.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + if payload != string(buf[:n]) { + t.Fatalf("expected %s, got %s", payload, string(buf[:n])) + } +} + +func TestBindToUnavailabePeer(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + _, err = clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + log.Infof("closing client") + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } +} + +func TestBindReconnect(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + _, err = clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + clientBob := NewClient(ctx, serverURL, hmacTokenStore, "bob") + err = clientBob.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + chBob, err := clientBob.OpenConn("alice") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + log.Infof("closing client Alice") + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } + + clientAlice = NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + chAlice, err := clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + testString := "hello alice, I am bob" + _, err = chBob.Write([]byte(testString)) + if err != nil { + t.Errorf("failed to write to channel: %s", err) + } + + buf := make([]byte, 65535) + n, err := chAlice.Read(buf) + if err != nil { + t.Errorf("failed to read from channel: %s", err) + } + + if testString != string(buf[:n]) { + t.Errorf("expected %s, got %s", testString, string(buf[:n])) + } + + log.Infof("closing client") + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } +} + +func TestCloseConn(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + log.Infof("closing server") + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + t.Errorf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + log.Infof("closing connection") + err = conn.Close() + if err != nil { + t.Errorf("failed to close connection: %s", err) + } + + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + _, err = conn.Write([]byte("hello")) + if err == nil { + t.Errorf("unexpected writing from closed connection") + } +} + +func TestCloseRelayConn(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + log.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, "alice") + err = clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + + conn, err := clientAlice.OpenConn("bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + _ = clientAlice.relayConn.Close() + + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + _, err = clientAlice.OpenConn("bob") + if err == nil { + t.Errorf("unexpected opening connection to closed server") + } +} + +func TestCloseByServer(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv1, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + + go func() { + err := srv1.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + log.Debugf("connect by alice") + relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) + err = relayClient.Connect() + if err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + + disconnected := make(chan struct{}) + relayClient.SetOnDisconnectListener(func() { + log.Infof("client disconnected") + close(disconnected) + }) + + err = srv1.Shutdown(ctx) + if err != nil { + t.Fatalf("failed to close server: %s", err) + } + + select { + case <-disconnected: + case <-time.After(3 * time.Second): + log.Fatalf("timeout waiting for client to disconnect") + } + + _, err = relayClient.OpenConn("bob") + if err == nil { + t.Errorf("unexpected opening connection to closed server") + } +} + +func TestCloseByClient(t *testing.T) { + ctx := context.Background() + + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + log.Debugf("connect by alice") + relayClient := NewClient(ctx, serverURL, hmacTokenStore, idAlice) + err = relayClient.Connect() + if err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + + err = relayClient.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } + + _, err = relayClient.OpenConn("bob") + if err == nil { + t.Errorf("unexpected opening connection to closed server") + } + + err = srv.Shutdown(ctx) + if err != nil { + t.Fatalf("failed to close server: %s", err) + } +} + +func TestCloseNotDrainedChannel(t *testing.T) { + ctx := context.Background() + idAlice := "alice" + idBob := "bob" + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientAlice.Close() + if err != nil { + t.Errorf("failed to close Alice client: %s", err) + } + }() + + clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientBob.Close() + if err != nil { + t.Errorf("failed to close Bob client: %s", err) + } + }() + + connAliceToBob, err := clientAlice.OpenConn(idBob) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.OpenConn(idAlice) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + // the internal channel buffer size is 2. So we should overflow it + for i := 0; i < 5; i++ { + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + } + + // wait for delivery + time.Sleep(1 * time.Second) + err = connBobToAlice.Close() + if err != nil { + t.Errorf("failed to close channel: %s", err) + } +} + +func waitForServerToStart(errChan chan error) error { + select { + case err := <-errChan: + if err != nil { + return err + } + case <-time.After(300 * time.Millisecond): + return nil + } + return nil +} diff --git a/relay/client/conn.go b/relay/client/conn.go new file mode 100644 index 000000000..b4ff903e8 --- /dev/null +++ b/relay/client/conn.go @@ -0,0 +1,76 @@ +package client + +import ( + "io" + "net" + "time" +) + +// Conn represent a connection to a relayed remote peer. +type Conn struct { + client *Client + dstID []byte + dstStringID string + messageChan chan Msg + instanceURL *RelayAddr +} + +// NewConn creates a new connection to a relayed remote peer. +// client: the client instance, it used to send messages to the destination peer +// dstID: the destination peer ID +// dstStringID: the destination peer ID in string format +// messageChan: the channel where the messages will be received +// instanceURL: the relay instance URL, it used to get the proper server instance address for the remote peer +func NewConn(client *Client, dstID []byte, dstStringID string, messageChan chan Msg, instanceURL *RelayAddr) *Conn { + c := &Conn{ + client: client, + dstID: dstID, + dstStringID: dstStringID, + messageChan: messageChan, + instanceURL: instanceURL, + } + + return c +} + +func (c *Conn) Write(p []byte) (n int, err error) { + return c.client.writeTo(c, c.dstStringID, c.dstID, p) +} + +func (c *Conn) Read(b []byte) (n int, err error) { + msg, ok := <-c.messageChan + if !ok { + return 0, io.EOF + } + + n = copy(b, msg.Payload) + msg.Free() + return n, nil +} + +func (c *Conn) Close() error { + return c.client.closeConn(c, c.dstStringID) +} + +func (c *Conn) LocalAddr() net.Addr { + return c.client.relayConn.LocalAddr() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.instanceURL +} + +func (c *Conn) SetDeadline(t time.Time) error { + //TODO implement me + panic("SetDeadline is not implemented") +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("SetReadDeadline is not implemented") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("SetReadDeadline is not implemented") +} diff --git a/relay/client/dialer/ws/addr.go b/relay/client/dialer/ws/addr.go new file mode 100644 index 000000000..43f5dd6af --- /dev/null +++ b/relay/client/dialer/ws/addr.go @@ -0,0 +1,13 @@ +package ws + +type WebsocketAddr struct { + addr string +} + +func (a WebsocketAddr) Network() string { + return "websocket" +} + +func (a WebsocketAddr) String() string { + return a.addr +} diff --git a/relay/client/dialer/ws/conn.go b/relay/client/dialer/ws/conn.go new file mode 100644 index 000000000..e7f771b8d --- /dev/null +++ b/relay/client/dialer/ws/conn.go @@ -0,0 +1,66 @@ +package ws + +import ( + "context" + "fmt" + "net" + "time" + + "nhooyr.io/websocket" +) + +type Conn struct { + ctx context.Context + *websocket.Conn + remoteAddr WebsocketAddr +} + +func NewConn(wsConn *websocket.Conn, serverAddress string) net.Conn { + return &Conn{ + ctx: context.Background(), + Conn: wsConn, + remoteAddr: WebsocketAddr{serverAddress}, + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + t, ioReader, err := c.Conn.Reader(c.ctx) + if err != nil { + return 0, err + } + + if t != websocket.MessageBinary { + return 0, fmt.Errorf("unexpected message type") + } + + return ioReader.Read(b) +} + +func (c *Conn) Write(b []byte) (n int, err error) { + err = c.Conn.Write(c.ctx, websocket.MessageBinary, b) + return 0, err +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *Conn) LocalAddr() net.Addr { + return WebsocketAddr{addr: "unknown"} +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return fmt.Errorf("SetReadDeadline is not implemented") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline is not implemented") +} + +func (c *Conn) SetDeadline(t time.Time) error { + return fmt.Errorf("SetDeadline is not implemented") +} + +func (c *Conn) Close() error { + return c.Conn.CloseNow() +} diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go new file mode 100644 index 000000000..d9388aafd --- /dev/null +++ b/relay/client/dialer/ws/ws.go @@ -0,0 +1,67 @@ +package ws + +import ( + "context" + "fmt" + "net" + "net/http" + "net/url" + "strings" + + log "github.com/sirupsen/logrus" + "nhooyr.io/websocket" + + "github.com/netbirdio/netbird/relay/server/listener/ws" + nbnet "github.com/netbirdio/netbird/util/net" +) + +func Dial(address string) (net.Conn, error) { + wsURL, err := prepareURL(address) + if err != nil { + return nil, err + } + + opts := &websocket.DialOptions{ + HTTPClient: httpClientNbDialer(), + } + + parsedURL, err := url.Parse(wsURL) + if err != nil { + return nil, err + } + parsedURL.Path = ws.URLPath + + wsConn, resp, err := websocket.Dial(context.Background(), parsedURL.String(), opts) + if err != nil { + log.Errorf("failed to dial to Relay server '%s': %s", wsURL, err) + return nil, err + } + if resp.Body != nil { + _ = resp.Body.Close() + } + + conn := NewConn(wsConn, address) + return conn, nil +} + +func prepareURL(address string) (string, error) { + if !strings.HasPrefix(address, "rel:") && !strings.HasPrefix(address, "rels:") { + return "", fmt.Errorf("unsupported scheme: %s", address) + } + + return strings.Replace(address, "rel", "ws", 1), nil +} + +func httpClientNbDialer() *http.Client { + customDialer := nbnet.NewDialer() + + customTransport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return customDialer.DialContext(ctx, network, addr) + }, + } + + return &http.Client{ + Transport: customTransport, + } +} diff --git a/relay/client/doc.go b/relay/client/doc.go new file mode 100644 index 000000000..1339251d9 --- /dev/null +++ b/relay/client/doc.go @@ -0,0 +1,12 @@ +/* +Package client contains the implementation of the Relay client. + +The Relay client is responsible for establishing a connection with the Relay server and sending and receiving messages, +Keep persistent connection with the Relay server and handle the connection issues. +It uses the WebSocket protocol for communication and optionally supports TLS (Transport Layer Security). + +If a peer wants to communicate with a peer on a different relay server, the manager will establish a new connection to +the relay server. The connection with these relay servers will be closed if there is no active connection. The peers +negotiate the common relay instance via signaling service. +*/ +package client diff --git a/relay/client/guard.go b/relay/client/guard.go new file mode 100644 index 000000000..f826cf1b6 --- /dev/null +++ b/relay/client/guard.go @@ -0,0 +1,48 @@ +package client + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + reconnectingTimeout = 5 * time.Second +) + +// Guard manage the reconnection tries to the Relay server in case of disconnection event. +type Guard struct { + ctx context.Context + relayClient *Client +} + +// NewGuard creates a new guard for the relay client. +func NewGuard(context context.Context, relayClient *Client) *Guard { + g := &Guard{ + ctx: context, + relayClient: relayClient, + } + return g +} + +// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection +// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent +func (g *Guard) OnDisconnected() { + ticker := time.NewTicker(reconnectingTimeout) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + err := g.relayClient.Connect() + if err != nil { + log.Errorf("failed to reconnect to relay server: %s", err) + continue + } + return + case <-g.ctx.Done(): + return + } + } +} diff --git a/relay/client/manager.go b/relay/client/manager.go new file mode 100644 index 000000000..4554c7c0f --- /dev/null +++ b/relay/client/manager.go @@ -0,0 +1,326 @@ +package client + +import ( + "container/list" + "context" + "fmt" + "net" + "reflect" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + relayAuth "github.com/netbirdio/netbird/relay/auth/hmac" +) + +var ( + relayCleanupInterval = 60 * time.Second + + ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") +) + +// RelayTrack hold the relay clients for the foreign relay servers. +// With the mutex can ensure we can open new connection in case the relay connection has been established with +// the relay server. +type RelayTrack struct { + sync.RWMutex + relayClient *Client + err error +} + +func NewRelayTrack() *RelayTrack { + return &RelayTrack{} +} + +type OnServerCloseListener func() + +// ManagerService is the interface for the relay manager. +type ManagerService interface { + Serve() error + OpenConn(serverAddress, peerKey string) (net.Conn, error) + AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error + RelayInstanceAddress() (string, error) + ServerURLs() []string + HasRelayAddress() bool + UpdateToken(token *relayAuth.Token) error +} + +// Manager is a manager for the relay client instances. It establishes one persistent connection to the given relay URL +// and automatically reconnect to them in case disconnection. +// The manager also manage temporary relay connection. If a client wants to communicate with a client on a +// different relay server, the manager will establish a new connection to the relay server. The connection with these +// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any +// unused relay connection and close it. +type Manager struct { + ctx context.Context + serverURLs []string + peerID string + tokenStore *relayAuth.TokenStore + + relayClient *Client + reconnectGuard *Guard + + relayClients map[string]*RelayTrack + relayClientsMutex sync.RWMutex + + onDisconnectedListeners map[string]*list.List + listenerLock sync.Mutex +} + +// NewManager creates a new manager instance. +// The serverURL address can be empty. In this case, the manager will not serve. +func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { + return &Manager{ + ctx: ctx, + serverURLs: serverURLs, + peerID: peerID, + tokenStore: &relayAuth.TokenStore{}, + relayClients: make(map[string]*RelayTrack), + onDisconnectedListeners: make(map[string]*list.List), + } +} + +// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for +// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection. +func (m *Manager) Serve() error { + if m.relayClient != nil { + return fmt.Errorf("manager already serving") + } + log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) + + sp := ServerPicker{ + TokenStore: m.tokenStore, + PeerID: m.peerID, + } + + client, err := sp.PickServer(m.ctx, m.serverURLs) + if err != nil { + return err + } + m.relayClient = client + + m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnDisconnectListener(func() { + m.onServerDisconnected(client.connectionURL) + }) + m.startCleanupLoop() + return nil +} + +// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be +// established via the relay server. If the peer is on a different relay server, the manager will establish a new +// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. +func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { + if m.relayClient == nil { + return nil, ErrRelayClientNotConnected + } + + foreign, err := m.isForeignServer(serverAddress) + if err != nil { + return nil, err + } + + var ( + netConn net.Conn + ) + if !foreign { + log.Debugf("open peer connection via permanent server: %s", peerKey) + netConn, err = m.relayClient.OpenConn(peerKey) + } else { + log.Debugf("open peer connection via foreign server: %s", serverAddress) + netConn, err = m.openConnVia(serverAddress, peerKey) + } + if err != nil { + return nil, err + } + + return netConn, err +} + +// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection +// closed. +func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { + foreign, err := m.isForeignServer(serverAddress) + if err != nil { + return err + } + + var listenerAddr string + if foreign { + listenerAddr = serverAddress + } else { + listenerAddr = m.relayClient.connectionURL + } + m.addListener(listenerAddr, onClosedListener) + return nil +} + +// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is +// lost. This address will be sent to the target peer to choose the common relay server for the communication. +func (m *Manager) RelayInstanceAddress() (string, error) { + if m.relayClient == nil { + return "", ErrRelayClientNotConnected + } + return m.relayClient.ServerInstanceURL() +} + +// ServerURLs returns the addresses of the relay servers. +func (m *Manager) ServerURLs() []string { + return m.serverURLs +} + +// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with +// Relay service. +func (m *Manager) HasRelayAddress() bool { + return len(m.serverURLs) > 0 +} + +// UpdateToken updates the token in the token store. +func (m *Manager) UpdateToken(token *relayAuth.Token) error { + return m.tokenStore.UpdateToken(token) +} + +func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { + // check if already has a connection to the desired relay server + m.relayClientsMutex.RLock() + rt, ok := m.relayClients[serverAddress] + if ok { + rt.RLock() + m.relayClientsMutex.RUnlock() + defer rt.RUnlock() + if rt.err != nil { + return nil, rt.err + } + return rt.relayClient.OpenConn(peerKey) + } + m.relayClientsMutex.RUnlock() + + // if not, establish a new connection but check it again (because changed the lock type) before starting the + // connection + m.relayClientsMutex.Lock() + rt, ok = m.relayClients[serverAddress] + if ok { + rt.RLock() + m.relayClientsMutex.Unlock() + defer rt.RUnlock() + if rt.err != nil { + return nil, rt.err + } + return rt.relayClient.OpenConn(peerKey) + } + + // create a new relay client and store it in the relayClients map + rt = NewRelayTrack() + rt.Lock() + m.relayClients[serverAddress] = rt + m.relayClientsMutex.Unlock() + + relayClient := NewClient(m.ctx, serverAddress, m.tokenStore, m.peerID) + err := relayClient.Connect() + if err != nil { + rt.err = err + rt.Unlock() + m.relayClientsMutex.Lock() + delete(m.relayClients, serverAddress) + m.relayClientsMutex.Unlock() + return nil, err + } + // if connection closed then delete the relay client from the list + relayClient.SetOnDisconnectListener(func() { + m.onServerDisconnected(serverAddress) + }) + rt.relayClient = relayClient + rt.Unlock() + + conn, err := relayClient.OpenConn(peerKey) + if err != nil { + return nil, err + } + return conn, nil +} + +func (m *Manager) onServerDisconnected(serverAddress string) { + if serverAddress == m.relayClient.connectionURL { + go m.reconnectGuard.OnDisconnected() + } + + m.notifyOnDisconnectListeners(serverAddress) +} + +func (m *Manager) isForeignServer(address string) (bool, error) { + rAddr, err := m.relayClient.ServerInstanceURL() + if err != nil { + return false, fmt.Errorf("relay client not connected") + } + return rAddr != address, nil +} + +func (m *Manager) startCleanupLoop() { + if m.ctx.Err() != nil { + return + } + + ticker := time.NewTicker(relayCleanupInterval) + go func() { + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.cleanUpUnusedRelays() + } + } + }() +} + +func (m *Manager) cleanUpUnusedRelays() { + m.relayClientsMutex.Lock() + defer m.relayClientsMutex.Unlock() + + for addr, rt := range m.relayClients { + rt.Lock() + if rt.relayClient.HasConns() { + rt.Unlock() + continue + } + rt.relayClient.SetOnDisconnectListener(nil) + go func() { + _ = rt.relayClient.Close() + }() + log.Debugf("clean up unused relay server connection: %s", addr) + delete(m.relayClients, addr) + rt.Unlock() + } +} + +func (m *Manager) addListener(serverAddress string, onClosedListener OnServerCloseListener) { + m.listenerLock.Lock() + defer m.listenerLock.Unlock() + l, ok := m.onDisconnectedListeners[serverAddress] + if !ok { + l = list.New() + } + for e := l.Front(); e != nil; e = e.Next() { + if reflect.ValueOf(e.Value).Pointer() == reflect.ValueOf(onClosedListener).Pointer() { + return + } + } + l.PushBack(onClosedListener) + m.onDisconnectedListeners[serverAddress] = l +} + +func (m *Manager) notifyOnDisconnectListeners(serverAddress string) { + m.listenerLock.Lock() + defer m.listenerLock.Unlock() + + l, ok := m.onDisconnectedListeners[serverAddress] + if !ok { + return + } + for e := l.Front(); e != nil; e = e.Next() { + go e.Value.(OnServerCloseListener)() + } + delete(m.onDisconnectedListeners, serverAddress) +} diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go new file mode 100644 index 000000000..e9cc2c581 --- /dev/null +++ b/relay/client/manager_test.go @@ -0,0 +1,432 @@ +package client + +import ( + "context" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/relay/server" +) + +func TestEmptyURL(t *testing.T) { + mgr := NewManager(context.Background(), nil, "alice") + err := mgr.Serve() + if err == nil { + t.Errorf("expected error, got nil") + } +} + +func TestForeignConn(t *testing.T) { + ctx := context.Background() + + srvCfg1 := server.ListenerConfig{ + Address: "localhost:1234", + } + srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv1.Listen(srvCfg1) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv1.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + srvCfg2 := server.ListenerConfig{ + Address: "localhost:2234", + } + srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan2 := make(chan error, 1) + go func() { + err := srv2.Listen(srvCfg2) + if err != nil { + errChan2 <- err + } + }() + + defer func() { + err := srv2.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan2); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + log.Debugf("connect by alice") + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) + err = clientAlice.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + idBob := "bob" + log.Debugf("connect by bob") + clientBob := NewManager(mCtx, toURL(srvCfg2), idBob) + err = clientBob.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + bobsSrvAddr, err := clientBob.RelayInstanceAddress() + if err != nil { + t.Fatalf("failed to get relay address: %s", err) + } + connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr, idBob) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + connBobToAlice, err := clientBob.OpenConn(bobsSrvAddr, idAlice) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + buf := make([]byte, 65535) + n, err := connBobToAlice.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + _, err = connBobToAlice.Write(buf[:n]) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + n, err = connAliceToBob.Read(buf) + if err != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + if payload != string(buf[:n]) { + t.Fatalf("expected %s, got %s", payload, string(buf[:n])) + } +} + +func TestForeginConnClose(t *testing.T) { + ctx := context.Background() + + srvCfg1 := server.ListenerConfig{ + Address: "localhost:1234", + } + srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv1.Listen(srvCfg1) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv1.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + srvCfg2 := server.ListenerConfig{ + Address: "localhost:2234", + } + srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan2 := make(chan error, 1) + go func() { + err := srv2.Listen(srvCfg2) + if err != nil { + errChan2 <- err + } + }() + + defer func() { + err := srv2.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan2); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + log.Debugf("connect by alice") + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + mgr := NewManager(mCtx, toURL(srvCfg1), idAlice) + err = mgr.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + err = conn.Close() + if err != nil { + t.Fatalf("failed to close connection: %s", err) + } +} + +func TestForeginAutoClose(t *testing.T) { + ctx := context.Background() + relayCleanupInterval = 1 * time.Second + srvCfg1 := server.ListenerConfig{ + Address: "localhost:1234", + } + srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + t.Log("binding server 1.") + err := srv1.Listen(srvCfg1) + if err != nil { + errChan <- err + } + }() + + defer func() { + t.Logf("closing server 1.") + err := srv1.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + t.Logf("server 1. closed") + }() + + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + srvCfg2 := server.ListenerConfig{ + Address: "localhost:2234", + } + srv2, err := server.NewServer(otel.Meter(""), srvCfg2.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan2 := make(chan error, 1) + go func() { + t.Log("binding server 2.") + err := srv2.Listen(srvCfg2) + if err != nil { + errChan2 <- err + } + }() + defer func() { + t.Logf("closing server 2.") + err := srv2.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + t.Logf("server 2 closed.") + }() + + if err := waitForServerToStart(errChan2); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + t.Log("connect to server 1.") + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + mgr := NewManager(mCtx, toURL(srvCfg1), idAlice) + err = mgr.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + t.Log("open connection to another peer") + conn, err := mgr.OpenConn(toURL(srvCfg2)[0], "anotherpeer") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + t.Log("close conn") + err = conn.Close() + if err != nil { + t.Fatalf("failed to close connection: %s", err) + } + + t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second) + time.Sleep(relayCleanupInterval + 1*time.Second) + if len(mgr.relayClients) != 0 { + t.Errorf("expected 0, got %d", len(mgr.relayClients)) + } + + t.Logf("closing manager") +} + +func TestAutoReconnect(t *testing.T) { + ctx := context.Background() + reconnectingTimeout = 2 * time.Second + + srvCfg := server.ListenerConfig{ + Address: "localhost:1234", + } + srv, err := server.NewServer(otel.Meter(""), srvCfg.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + log.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + clientAlice := NewManager(mCtx, toURL(srvCfg), "alice") + err = clientAlice.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + ra, err := clientAlice.RelayInstanceAddress() + if err != nil { + t.Errorf("failed to get relay address: %s", err) + } + conn, err := clientAlice.OpenConn(ra, "bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + t.Log("closing client relay connection") + // todo figure out moc server + _ = clientAlice.relayClient.relayConn.Close() + t.Log("start test reading") + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + log.Infof("waiting for reconnection") + time.Sleep(reconnectingTimeout + 1*time.Second) + + log.Infof("reopent the connection") + _, err = clientAlice.OpenConn(ra, "bob") + if err != nil { + t.Errorf("failed to open channel: %s", err) + } +} + +func TestNotifierDoubleAdd(t *testing.T) { + ctx := context.Background() + + srvCfg1 := server.ListenerConfig{ + Address: "localhost:1234", + } + srv1, err := server.NewServer(otel.Meter(""), srvCfg1.Address, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv1.Listen(srvCfg1) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv1.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + idAlice := "alice" + log.Debugf("connect by alice") + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + clientAlice := NewManager(mCtx, toURL(srvCfg1), idAlice) + err = clientAlice.Serve() + if err != nil { + t.Fatalf("failed to serve manager: %s", err) + } + + conn1, err := clientAlice.OpenConn(clientAlice.ServerURLs()[0], "idBob") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + fnCloseListener := OnServerCloseListener(func() { + log.Infof("close listener") + }) + + err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener) + if err != nil { + t.Fatalf("failed to add close listener: %s", err) + } + + err = clientAlice.AddCloseListener(clientAlice.ServerURLs()[0], fnCloseListener) + if err != nil { + t.Fatalf("failed to add close listener: %s", err) + } + + err = conn1.Close() + if err != nil { + t.Errorf("failed to close connection: %s", err) + } + +} + +func toURL(address server.ListenerConfig) []string { + return []string{"rel://" + address.Address} +} diff --git a/relay/client/picker.go b/relay/client/picker.go new file mode 100644 index 000000000..13b0547aa --- /dev/null +++ b/relay/client/picker.go @@ -0,0 +1,98 @@ +package client + +import ( + "context" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" +) + +const ( + connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 +) + +type connResult struct { + RelayClient *Client + Url string + Err error +} + +type ServerPicker struct { + TokenStore *auth.TokenStore + PeerID string +} + +func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { + ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + defer cancel() + + totalServers := len(urls) + + connResultChan := make(chan connResult, totalServers) + successChan := make(chan connResult, 1) + concurrentLimiter := make(chan struct{}, maxConcurrentServers) + + for _, url := range urls { + // todo check if we have a successful connection so we do not need to connect to other servers + concurrentLimiter <- struct{}{} + go func(url string) { + defer func() { + <-concurrentLimiter + }() + sp.startConnection(parentCtx, connResultChan, url) + }(url) + } + + go sp.processConnResults(connResultChan, successChan) + + select { + case cr, ok := <-successChan: + if !ok { + return nil, errors.New("failed to connect to any relay server: all attempts failed") + } + log.Infof("chosen home Relay server: %s", cr.Url) + return cr.RelayClient, nil + case <-ctx.Done(): + return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) + } +} + +func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { + log.Infof("try to connecting to relay server: %s", url) + relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) + err := relayClient.Connect() + resultChan <- connResult{ + RelayClient: relayClient, + Url: url, + Err: err, + } +} + +func (sp *ServerPicker) processConnResults(resultChan chan connResult, successChan chan connResult) { + var hasSuccess bool + for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ { + cr := <-resultChan + if cr.Err != nil { + log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + continue + } + log.Infof("connected to Relay server: %s", cr.Url) + + if hasSuccess { + log.Infof("closing unnecessary Relay connection to: %s", cr.Url) + if err := cr.RelayClient.Close(); err != nil { + log.Errorf("failed to close connection to %s: %v", cr.Url, err) + } + continue + } + + hasSuccess = true + successChan <- cr + } + close(successChan) +} diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go new file mode 100644 index 000000000..eb14581e0 --- /dev/null +++ b/relay/client/picker_test.go @@ -0,0 +1,31 @@ +package client + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestServerPicker_UnavailableServers(t *testing.T) { + sp := ServerPicker{ + TokenStore: nil, + PeerID: "test", + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go func() { + _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"}) + if err == nil { + t.Error(err) + } + cancel() + }() + + <-ctx.Done() + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Errorf("PickServer() took too long to complete") + } +} diff --git a/relay/cmd/env.go b/relay/cmd/env.go new file mode 100644 index 000000000..3c15ebe1f --- /dev/null +++ b/relay/cmd/env.go @@ -0,0 +1,35 @@ +package cmd + +import ( + "os" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_ +func setFlagsFromEnvVars(cmd *cobra.Command) { + flags := cmd.PersistentFlags() + flags.VisitAll(func(f *pflag.Flag) { + newEnvVar := flagNameToEnvVar(f.Name, "NB_") + value, present := os.LookupEnv(newEnvVar) + if !present { + return + } + + err := flags.Set(f.Name, value) + if err != nil { + log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err) + } + }) +} + +// flagNameToEnvVar converts flag name to environment var name adding a prefix, +// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix) +func flagNameToEnvVar(cmdFlag string, prefix string) string { + parsed := strings.ReplaceAll(cmdFlag, "-", "_") + upper := strings.ToUpper(parsed) + return prefix + upper +} diff --git a/relay/cmd/root.go b/relay/cmd/root.go new file mode 100644 index 000000000..d603ff73b --- /dev/null +++ b/relay/cmd/root.go @@ -0,0 +1,215 @@ +package cmd + +import ( + "context" + "crypto/sha256" + "crypto/tls" + "errors" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/encryption" + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/signal/metrics" + "github.com/netbirdio/netbird/util" +) + +type Config struct { + ListenAddress string + // in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection + // it is a domain:port or ip:port + ExposedAddress string + MetricsPort int + LetsencryptEmail string + LetsencryptDataDir string + LetsencryptDomains []string + // in case of using Route 53 for DNS challenge the credentials should be provided in the environment variables or + // in the AWS credentials file + LetsencryptAWSRoute53 bool + TlsCertFile string + TlsKeyFile string + AuthSecret string + LogLevel string + LogFile string +} + +func (c Config) Validate() error { + if c.ExposedAddress == "" { + return fmt.Errorf("exposed address is required") + } + if c.AuthSecret == "" { + return fmt.Errorf("auth secret is required") + } + return nil +} + +func (c Config) HasCertConfig() bool { + return c.TlsCertFile != "" && c.TlsKeyFile != "" +} + +func (c Config) HasLetsEncrypt() bool { + return c.LetsencryptDataDir != "" && c.LetsencryptDomains != nil && len(c.LetsencryptDomains) > 0 +} + +var ( + cobraConfig *Config + rootCmd = &cobra.Command{ + Use: "relay", + Short: "Relay service", + Long: "Relay service for Netbird agents", + SilenceUsage: true, + SilenceErrors: true, + RunE: execute, + } +) + +func init() { + _ = util.InitLog("trace", "console") + cobraConfig = &Config{} + rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers") + rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.") + rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") + rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration") + rootCmd.PersistentFlags().BoolVar(&cobraConfig.LetsencryptAWSRoute53, "letsencrypt-aws-route53", false, "use AWS Route 53 for Let's Encrypt DNS challenge") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsCertFile, "tls-cert-file", "c", "", "") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.TlsKeyFile, "tls-key-file", "k", "", "") + rootCmd.PersistentFlags().StringVarP(&cobraConfig.AuthSecret, "auth-secret", "s", "", "auth secret") + rootCmd.PersistentFlags().StringVar(&cobraConfig.LogLevel, "log-level", "info", "log level") + rootCmd.PersistentFlags().StringVar(&cobraConfig.LogFile, "log-file", "console", "log file") + + setFlagsFromEnvVars(rootCmd) +} + +func Execute() error { + return rootCmd.Execute() +} + +func waitForExitSignal() { + osSigs := make(chan os.Signal, 1) + signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM) + <-osSigs +} + +func execute(cmd *cobra.Command, args []string) error { + err := cobraConfig.Validate() + if err != nil { + log.Debugf("invalid config: %s", err) + return fmt.Errorf("invalid config: %s", err) + } + + err = util.InitLog(cobraConfig.LogLevel, cobraConfig.LogFile) + if err != nil { + log.Debugf("failed to initialize log: %s", err) + return fmt.Errorf("failed to initialize log: %s", err) + } + + metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "") + if err != nil { + log.Debugf("setup metrics: %v", err) + return fmt.Errorf("setup metrics: %v", err) + } + + go func() { + log.Infof("running metrics server: %s%s", metricsServer.Addr, metricsServer.Endpoint) + if err := metricsServer.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) { + log.Fatalf("Failed to start metrics server: %v", err) + } + }() + + srvListenerCfg := server.ListenerConfig{ + Address: cobraConfig.ListenAddress, + } + + tlsConfig, tlsSupport, err := handleTLSConfig(cobraConfig) + if err != nil { + log.Debugf("failed to setup TLS config: %s", err) + return fmt.Errorf("failed to setup TLS config: %s", err) + } + srvListenerCfg.TLSConfig = tlsConfig + + hashedSecret := sha256.Sum256([]byte(cobraConfig.AuthSecret)) + authenticator := auth.NewTimedHMACValidator(hashedSecret[:], 24*time.Hour) + + srv, err := server.NewServer(metricsServer.Meter, cobraConfig.ExposedAddress, tlsSupport, authenticator) + if err != nil { + log.Debugf("failed to create relay server: %v", err) + return fmt.Errorf("failed to create relay server: %v", err) + } + log.Infof("server will be available on: %s", srv.InstanceURL()) + go func() { + if err := srv.Listen(srvListenerCfg); err != nil { + log.Fatalf("failed to bind server: %s", err) + } + }() + + // it will block until exit signal + waitForExitSignal() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var shutDownErrors error + if err := srv.Shutdown(ctx); err != nil { + shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close server: %s", err)) + } + + log.Infof("shutting down metrics server") + if err := metricsServer.Shutdown(ctx); err != nil { + shutDownErrors = multierror.Append(shutDownErrors, fmt.Errorf("failed to close metrics server: %v", err)) + } + return shutDownErrors +} + +func handleTLSConfig(cfg *Config) (*tls.Config, bool, error) { + if cfg.LetsencryptAWSRoute53 { + log.Debugf("using Let's Encrypt DNS resolver with Route 53 support") + r53 := encryption.Route53TLS{ + DataDir: cfg.LetsencryptDataDir, + Email: cfg.LetsencryptEmail, + Domains: cfg.LetsencryptDomains, + } + tlsCfg, err := r53.GetCertificate() + if err != nil { + return nil, false, fmt.Errorf("%s", err) + } + return tlsCfg, true, nil + } + + if cfg.HasLetsEncrypt() { + log.Infof("setting up TLS with Let's Encrypt.") + tlsCfg, err := setupTLSCertManager(cfg.LetsencryptDataDir, cfg.LetsencryptDomains...) + if err != nil { + return nil, false, fmt.Errorf("%s", err) + } + return tlsCfg, true, nil + } + + if cfg.HasCertConfig() { + log.Debugf("using file based TLS config") + tlsCfg, err := encryption.LoadTLSConfig(cfg.TlsCertFile, cfg.TlsKeyFile) + if err != nil { + return nil, false, fmt.Errorf("%s", err) + } + return tlsCfg, true, nil + } + return nil, false, nil +} + +func setupTLSCertManager(letsencryptDataDir string, letsencryptDomains ...string) (*tls.Config, error) { + certManager, err := encryption.CreateCertManager(letsencryptDataDir, letsencryptDomains...) + if err != nil { + return nil, fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err) + } + return certManager.TLSConfig(), nil +} diff --git a/relay/doc.go b/relay/doc.go new file mode 100644 index 000000000..56e010e3e --- /dev/null +++ b/relay/doc.go @@ -0,0 +1,14 @@ +//Package main +/* +The `relay` package contains the implementation of the Relay server and client. The Relay server can be used to relay +messages between peers on a single network channel. In this implementation the transport layer is the WebSocket +protocol. + +Between the server and client communication has been design a custom protocol and message format. These messages are +transported over the WebSocket connection. Optionally the server can use TLS to secure the communication. + +The service can support multiple Relay server instances. For this purpose the peers must know the server instance URL. +This URL will be sent to the target peer to choose the common Relay server for the communication via Signal service. + +*/ +package main diff --git a/relay/healthcheck/doc.go b/relay/healthcheck/doc.go new file mode 100644 index 000000000..da9689c6b --- /dev/null +++ b/relay/healthcheck/doc.go @@ -0,0 +1,17 @@ +/* +The `healthcheck` package is responsible for managing the health checks between the client and the relay server. It +ensures that the connection between the client and the server are alive and functioning properly. + +The `Sender` struct is responsible for sending health check signals to the receiver. The receiver listens for these +signals and sends a new signal back to the sender to acknowledge that the signal has been received. If the sender does +not receive an acknowledgment signal within a certain time frame, it will send a timeout signal via timeout channel +and stop working. + +The `Receiver` struct is responsible for receiving the health check signals from the sender. If the receiver does not +receive a signal within a certain time frame, it will send a timeout signal via the OnTimeout channel and stop working. + +In the Relay usage the signal is sent to the peer in message type Healthcheck. In case of timeout the connection is +closed and the peer is removed from the relay. +*/ + +package healthcheck diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go new file mode 100644 index 000000000..b3503d5db --- /dev/null +++ b/relay/healthcheck/receiver.go @@ -0,0 +1,94 @@ +package healthcheck + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" +) + +var ( + heartbeatTimeout = healthCheckInterval + 10*time.Second +) + +// Receiver is a healthcheck receiver +// It will listen for heartbeat and check if the heartbeat is not received in a certain time +// If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work +// The heartbeat timeout is a bit longer than the sender's healthcheck interval +type Receiver struct { + OnTimeout chan struct{} + log *log.Entry + ctx context.Context + ctxCancel context.CancelFunc + heartbeat chan struct{} + alive bool + attemptThreshold int +} + +// NewReceiver creates a new healthcheck receiver and start the timer in the background +func NewReceiver(log *log.Entry) *Receiver { + ctx, ctxCancel := context.WithCancel(context.Background()) + + r := &Receiver{ + OnTimeout: make(chan struct{}, 1), + log: log, + ctx: ctx, + ctxCancel: ctxCancel, + heartbeat: make(chan struct{}, 1), + attemptThreshold: getAttemptThresholdFromEnv(), + } + + go r.waitForHealthcheck() + return r +} + +// Heartbeat acknowledge the heartbeat has been received +func (r *Receiver) Heartbeat() { + select { + case r.heartbeat <- struct{}{}: + default: + } +} + +// Stop check the timeout and do not send new notifications +func (r *Receiver) Stop() { + r.ctxCancel() +} + +func (r *Receiver) waitForHealthcheck() { + ticker := time.NewTicker(heartbeatTimeout) + defer ticker.Stop() + defer r.ctxCancel() + defer close(r.OnTimeout) + + failureCounter := 0 + for { + select { + case <-r.heartbeat: + r.alive = true + failureCounter = 0 + case <-ticker.C: + if r.alive { + r.alive = false + continue + } + + failureCounter++ + if failureCounter < r.attemptThreshold { + r.log.Warnf("healthcheck failed, attempt %d", failureCounter) + continue + } + r.notifyTimeout() + return + case <-r.ctx.Done(): + return + } + } +} + +func (r *Receiver) notifyTimeout() { + select { + case r.OnTimeout <- struct{}{}: + default: + } +} diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go new file mode 100644 index 000000000..3b3e32fe6 --- /dev/null +++ b/relay/healthcheck/receiver_test.go @@ -0,0 +1,97 @@ +package healthcheck + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + log "github.com/sirupsen/logrus" +) + +func TestNewReceiver(t *testing.T) { + heartbeatTimeout = 5 * time.Second + r := NewReceiver(log.WithContext(context.Background())) + + select { + case <-r.OnTimeout: + t.Error("unexpected timeout") + case <-time.After(1 * time.Second): + + } +} + +func TestNewReceiverNotReceive(t *testing.T) { + heartbeatTimeout = 1 * time.Second + r := NewReceiver(log.WithContext(context.Background())) + + select { + case <-r.OnTimeout: + case <-time.After(2 * time.Second): + t.Error("timeout not received") + } +} + +func TestNewReceiverAck(t *testing.T) { + heartbeatTimeout = 2 * time.Second + r := NewReceiver(log.WithContext(context.Background())) + + r.Heartbeat() + + select { + case <-r.OnTimeout: + t.Error("unexpected timeout") + case <-time.After(3 * time.Second): + } +} + +func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { + testsCases := []struct { + name string + threshold int + resetCounterOnce bool + }{ + {"Default attempt threshold", defaultAttemptThreshold, false}, + {"Custom attempt threshold", 3, false}, + {"Should reset threshold once", 2, true}, + } + + for _, tc := range testsCases { + t.Run(tc.name, func(t *testing.T) { + originalInterval := healthCheckInterval + originalTimeout := heartbeatTimeout + healthCheckInterval = 1 * time.Second + heartbeatTimeout = healthCheckInterval + 500*time.Millisecond + defer func() { + healthCheckInterval = originalInterval + heartbeatTimeout = originalTimeout + }() + //nolint:tenv + os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) + defer os.Unsetenv(defaultAttemptThresholdEnv) + + receiver := NewReceiver(log.WithField("test_name", tc.name)) + + testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + + if tc.resetCounterOnce { + receiver.Heartbeat() + t.Logf("reset counter once") + } + + select { + case <-receiver.OnTimeout: + if tc.resetCounterOnce { + t.Fatalf("should not have timed out before %s", testTimeout) + } + case <-time.After(testTimeout): + if tc.resetCounterOnce { + return + } + t.Fatalf("should have timed out before %s", testTimeout) + } + + }) + } +} diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go new file mode 100644 index 000000000..57b3015ec --- /dev/null +++ b/relay/healthcheck/sender.go @@ -0,0 +1,110 @@ +package healthcheck + +import ( + "context" + "os" + "strconv" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThreshold = 1 + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" +) + +var ( + healthCheckInterval = 25 * time.Second + healthCheckTimeout = 20 * time.Second +) + +// Sender is a healthcheck sender +// It will send healthcheck signal to the receiver +// If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work +// It will also stop if the context is canceled +type Sender struct { + log *log.Entry + // HealthCheck is a channel to send health check signal to the peer + HealthCheck chan struct{} + // Timeout is a channel to the health check signal is not received in a certain time + Timeout chan struct{} + + ack chan struct{} + alive bool + attemptThreshold int +} + +// NewSender creates a new healthcheck sender +func NewSender(log *log.Entry) *Sender { + hc := &Sender{ + log: log, + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + ack: make(chan struct{}, 1), + attemptThreshold: getAttemptThresholdFromEnv(), + } + + return hc +} + +// OnHCResponse sends an acknowledgment signal to the sender +func (hc *Sender) OnHCResponse() { + select { + case hc.ack <- struct{}{}: + default: + } +} + +func (hc *Sender) StartHealthCheck(ctx context.Context) { + ticker := time.NewTicker(healthCheckInterval) + defer ticker.Stop() + + timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + defer timeoutTicker.Stop() + + defer close(hc.HealthCheck) + defer close(hc.Timeout) + + failureCounter := 0 + for { + select { + case <-ticker.C: + hc.HealthCheck <- struct{}{} + case <-timeoutTicker.C: + if hc.alive { + hc.alive = false + continue + } + + failureCounter++ + if failureCounter < hc.attemptThreshold { + hc.log.Warnf("Health check failed attempt %d.", failureCounter) + continue + } + hc.Timeout <- struct{}{} + return + case <-hc.ack: + failureCounter = 0 + hc.alive = true + case <-ctx.Done(): + return + } + } +} + +func (hc *Sender) getTimeoutTime() time.Duration { + return healthCheckInterval + healthCheckTimeout +} + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go new file mode 100644 index 000000000..f21167025 --- /dev/null +++ b/relay/healthcheck/sender_test.go @@ -0,0 +1,205 @@ +package healthcheck + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + log "github.com/sirupsen/logrus" +) + +func TestMain(m *testing.M) { + // override the health check interval to speed up the test + healthCheckInterval = 2 * time.Second + healthCheckTimeout = 100 * time.Millisecond + code := m.Run() + os.Exit(code) +} + +func TestNewHealthPeriod(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hc := NewSender(log.WithContext(ctx)) + go hc.StartHealthCheck(ctx) + + iterations := 0 + for i := 0; i < 3; i++ { + select { + case <-hc.HealthCheck: + iterations++ + hc.OnHCResponse() + case <-hc.Timeout: + t.Fatalf("health check is timed out") + case <-time.After(healthCheckInterval + 100*time.Millisecond): + t.Fatalf("health check not received") + } + } +} + +func TestNewHealthFailed(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hc := NewSender(log.WithContext(ctx)) + go hc.StartHealthCheck(ctx) + + select { + case <-hc.Timeout: + case <-time.After(healthCheckInterval + healthCheckTimeout + 100*time.Millisecond): + t.Fatalf("health check is not timed out") + } +} + +func TestNewHealthcheckStop(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + hc := NewSender(log.WithContext(ctx)) + go hc.StartHealthCheck(ctx) + + time.Sleep(100 * time.Millisecond) + cancel() + + select { + case _, ok := <-hc.HealthCheck: + if ok { + t.Fatalf("health check on received") + } + case _, ok := <-hc.Timeout: + if ok { + t.Fatalf("health check on received") + } + case <-ctx.Done(): + // expected + case <-time.After(10 * time.Second): + t.Fatalf("is not exited") + } +} + +func TestTimeoutReset(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + hc := NewSender(log.WithContext(ctx)) + go hc.StartHealthCheck(ctx) + + iterations := 0 + for i := 0; i < 3; i++ { + select { + case <-hc.HealthCheck: + iterations++ + hc.OnHCResponse() + case <-hc.Timeout: + t.Fatalf("health check is timed out") + case <-time.After(healthCheckInterval + 100*time.Millisecond): + t.Fatalf("health check not received") + } + } + + select { + case <-hc.HealthCheck: + case <-hc.Timeout: + // expected + case <-ctx.Done(): + t.Fatalf("context is done") + case <-time.After(10 * time.Second): + t.Fatalf("is not exited") + } +} + +func TestSenderHealthCheckAttemptThreshold(t *testing.T) { + testsCases := []struct { + name string + threshold int + resetCounterOnce bool + }{ + {"Default attempt threshold", defaultAttemptThreshold, false}, + {"Custom attempt threshold", 3, false}, + {"Should reset threshold once", 2, true}, + } + + for _, tc := range testsCases { + t.Run(tc.name, func(t *testing.T) { + originalInterval := healthCheckInterval + originalTimeout := healthCheckTimeout + healthCheckInterval = 1 * time.Second + healthCheckTimeout = 500 * time.Millisecond + defer func() { + healthCheckInterval = originalInterval + healthCheckTimeout = originalTimeout + }() + + //nolint:tenv + os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) + defer os.Unsetenv(defaultAttemptThresholdEnv) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sender := NewSender(log.WithField("test_name", tc.name)) + go sender.StartHealthCheck(ctx) + + go func() { + responded := false + for { + select { + case <-ctx.Done(): + return + case _, ok := <-sender.HealthCheck: + if !ok { + return + } + if tc.resetCounterOnce && !responded { + responded = true + sender.OnHCResponse() + } + } + } + }() + + testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + + select { + case <-sender.Timeout: + if tc.resetCounterOnce { + t.Fatalf("should not have timed out before %s", testTimeout) + } + case <-time.After(testTimeout): + if tc.resetCounterOnce { + return + } + t.Fatalf("should have timed out before %s", testTimeout) + } + + }) + } + +} + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/relay/main.go b/relay/main.go new file mode 100644 index 000000000..e28f73603 --- /dev/null +++ b/relay/main.go @@ -0,0 +1,13 @@ +package main + +import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + log.Fatalf("failed to execute command: %v", err) + } +} diff --git a/relay/messages/address/address.go b/relay/messages/address/address.go new file mode 100644 index 000000000..707e73e55 --- /dev/null +++ b/relay/messages/address/address.go @@ -0,0 +1,21 @@ +// Deprecated: This package is deprecated and will be removed in a future release. +package address + +import ( + "bytes" + "encoding/gob" + "fmt" +) + +type Address struct { + URL string +} + +func (addr *Address) Marshal() ([]byte, error) { + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(addr); err != nil { + return nil, fmt.Errorf("encode Address: %w", err) + } + return buf.Bytes(), nil +} diff --git a/relay/messages/auth/auth.go b/relay/messages/auth/auth.go new file mode 100644 index 000000000..9c2511f2f --- /dev/null +++ b/relay/messages/auth/auth.go @@ -0,0 +1,43 @@ +// Deprecated: This package is deprecated and will be removed in a future release. +package auth + +import ( + "bytes" + "encoding/gob" + "fmt" +) + +type Algorithm int + +const ( + AlgoUnknown Algorithm = iota + AlgoHMACSHA256 + AlgoHMACSHA512 +) + +func (a Algorithm) String() string { + switch a { + case AlgoHMACSHA256: + return "HMAC-SHA256" + case AlgoHMACSHA512: + return "HMAC-SHA512" + default: + return "Unknown" + } +} + +type Msg struct { + AuthAlgorithm Algorithm + AdditionalData []byte +} + +func UnmarshalMsg(data []byte) (*Msg, error) { + var msg *Msg + + buf := bytes.NewBuffer(data) + dec := gob.NewDecoder(buf) + if err := dec.Decode(&msg); err != nil { + return nil, fmt.Errorf("decode Msg: %w", err) + } + return msg, nil +} diff --git a/relay/messages/doc.go b/relay/messages/doc.go new file mode 100644 index 000000000..4c719df3a --- /dev/null +++ b/relay/messages/doc.go @@ -0,0 +1,5 @@ +/* +Package messages provides the message types that are used to communicate between the relay and the client. +This package is used to determine the type of message that is being sent and received between the relay and the client. +*/ +package messages diff --git a/relay/messages/id.go b/relay/messages/id.go new file mode 100644 index 000000000..e2162cd3b --- /dev/null +++ b/relay/messages/id.go @@ -0,0 +1,31 @@ +package messages + +import ( + "crypto/sha256" + "encoding/base64" + "fmt" +) + +const ( + prefixLength = 4 + IDSize = prefixLength + sha256.Size +) + +var ( + prefix = []byte("sha-") // 4 bytes +) + +// HashID generates a sha256 hash from the peerID and returns the hash and the human-readable string +func HashID(peerID string) ([]byte, string) { + idHash := sha256.Sum256([]byte(peerID)) + idHashString := string(prefix) + base64.StdEncoding.EncodeToString(idHash[:]) + var prefixedHash []byte + prefixedHash = append(prefixedHash, prefix...) + prefixedHash = append(prefixedHash, idHash[:]...) + return prefixedHash, idHashString +} + +// HashIDToString converts a hash to a human-readable string +func HashIDToString(idHash []byte) string { + return fmt.Sprintf("%s%s", idHash[:prefixLength], base64.StdEncoding.EncodeToString(idHash[prefixLength:])) +} diff --git a/relay/messages/id_test.go b/relay/messages/id_test.go new file mode 100644 index 000000000..271a8f90d --- /dev/null +++ b/relay/messages/id_test.go @@ -0,0 +1,13 @@ +package messages + +import ( + "testing" +) + +func TestHashID(t *testing.T) { + hashedID, hashedStringId := HashID("alice") + enc := HashIDToString(hashedID) + if enc != hashedStringId { + t.Errorf("expected %s, got %s", hashedStringId, enc) + } +} diff --git a/relay/messages/message.go b/relay/messages/message.go new file mode 100644 index 000000000..39ca0aa90 --- /dev/null +++ b/relay/messages/message.go @@ -0,0 +1,317 @@ +package messages + +import ( + "bytes" + "errors" + "fmt" +) + +const ( + MaxHandshakeSize = 212 + MaxHandshakeRespSize = 8192 + + CurrentProtocolVersion = 1 + + MsgTypeUnknown MsgType = 0 + // Deprecated: Use MsgTypeAuth instead. + MsgTypeHello MsgType = 1 + // Deprecated: Use MsgTypeAuthResponse instead. + MsgTypeHelloResponse MsgType = 2 + MsgTypeTransport MsgType = 3 + MsgTypeClose MsgType = 4 + MsgTypeHealthCheck MsgType = 5 + MsgTypeAuth = 6 + MsgTypeAuthResponse = 7 + + SizeOfVersionByte = 1 + SizeOfMsgType = 1 + + SizeOfProtoHeader = SizeOfVersionByte + SizeOfMsgType + + sizeOfMagicByte = 4 + + headerSizeTransport = IDSize + + headerSizeHello = sizeOfMagicByte + IDSize + headerSizeHelloResp = 0 + + headerSizeAuth = sizeOfMagicByte + IDSize + headerSizeAuthResp = 0 +) + +var ( + ErrInvalidMessageLength = errors.New("invalid message length") + ErrUnsupportedVersion = errors.New("unsupported version") + + magicHeader = []byte{0x21, 0x12, 0xA4, 0x42} + + healthCheckMsg = []byte{byte(CurrentProtocolVersion), byte(MsgTypeHealthCheck)} +) + +type MsgType byte + +func (m MsgType) String() string { + switch m { + case MsgTypeHello: + return "hello" + case MsgTypeHelloResponse: + return "hello response" + case MsgTypeAuth: + return "auth" + case MsgTypeAuthResponse: + return "auth response" + case MsgTypeTransport: + return "transport" + case MsgTypeClose: + return "close" + case MsgTypeHealthCheck: + return "health check" + default: + return "unknown" + } +} + +// ValidateVersion checks if the given version is supported by the protocol +func ValidateVersion(msg []byte) (int, error) { + if len(msg) < SizeOfVersionByte { + return 0, ErrInvalidMessageLength + } + version := int(msg[0]) + if version != CurrentProtocolVersion { + return 0, fmt.Errorf("%d: %w", version, ErrUnsupportedVersion) + } + return version, nil +} + +// DetermineClientMessageType determines the message type from the first the message +func DetermineClientMessageType(msg []byte) (MsgType, error) { + if len(msg) < SizeOfMsgType { + return 0, ErrInvalidMessageLength + } + + msgType := MsgType(msg[0]) + switch msgType { + case + MsgTypeHello, + MsgTypeAuth, + MsgTypeTransport, + MsgTypeClose, + MsgTypeHealthCheck: + return msgType, nil + default: + return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) + } +} + +// DetermineServerMessageType determines the message type from the first the message +func DetermineServerMessageType(msg []byte) (MsgType, error) { + if len(msg) < SizeOfMsgType { + return 0, ErrInvalidMessageLength + } + + msgType := MsgType(msg[0]) + switch msgType { + case + MsgTypeHelloResponse, + MsgTypeAuthResponse, + MsgTypeTransport, + MsgTypeClose, + MsgTypeHealthCheck: + return msgType, nil + default: + return MsgTypeUnknown, fmt.Errorf("invalid msg type %d", msgType) + } +} + +// Deprecated: Use MarshalAuthMsg instead. +// MarshalHelloMsg initial hello message +// The Hello message is the first message sent by a client after establishing a connection with the Relay server. This +// message is used to authenticate the client with the server. The authentication is done using an HMAC method. +// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will +// close the network connection without any response. +func MarshalHelloMsg(peerID []byte, additions []byte) ([]byte, error) { + if len(peerID) != IDSize { + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) + } + + msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeHello+len(additions)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeHello) + + copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + + msg = append(msg, peerID...) + msg = append(msg, additions...) + + return msg, nil +} + +// Deprecated: Use UnmarshalAuthMsg instead. +// UnmarshalHelloMsg extracts peerID and the additional data from the hello message. The Additional data is used to +// authenticate the client with the server. +func UnmarshalHelloMsg(msg []byte) ([]byte, []byte, error) { + if len(msg) < headerSizeHello { + return nil, nil, ErrInvalidMessageLength + } + if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + return nil, nil, errors.New("invalid magic header") + } + + return msg[sizeOfMagicByte:headerSizeHello], msg[headerSizeHello:], nil +} + +// Deprecated: Use MarshalAuthResponse instead. +// MarshalHelloResponse creates a response message to the hello message. +// In case of success connection the server response with a Hello Response message. This message contains the server's +// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay +// servers. +func MarshalHelloResponse(additionalData []byte) ([]byte, error) { + msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeHelloResp+len(additionalData)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeHelloResponse) + + msg = append(msg, additionalData...) + + return msg, nil +} + +// Deprecated: Use UnmarshalAuthResponse instead. +// UnmarshalHelloResponse extracts the additional data from the hello response message. +func UnmarshalHelloResponse(msg []byte) ([]byte, error) { + if len(msg) < headerSizeHelloResp { + return nil, ErrInvalidMessageLength + } + return msg, nil +} + +// MarshalAuthMsg initial authentication message +// The Auth message is the first message sent by a client after establishing a connection with the Relay server. This +// message is used to authenticate the client with the server. The authentication is done using an HMAC method. +// The protocol does not limit to use HMAC, it can be any other method. If the authentication failed the server will +// close the network connection without any response. +func MarshalAuthMsg(peerID []byte, authPayload []byte) ([]byte, error) { + if len(peerID) != IDSize { + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) + } + + msg := make([]byte, SizeOfProtoHeader+sizeOfMagicByte, SizeOfProtoHeader+headerSizeAuth+len(authPayload)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeAuth) + + copy(msg[SizeOfProtoHeader:SizeOfProtoHeader+sizeOfMagicByte], magicHeader) + + msg = append(msg, peerID...) + msg = append(msg, authPayload...) + + return msg, nil +} + +// UnmarshalAuthMsg extracts peerID and the auth payload from the message +func UnmarshalAuthMsg(msg []byte) ([]byte, []byte, error) { + if len(msg) < headerSizeAuth { + return nil, nil, ErrInvalidMessageLength + } + if !bytes.Equal(msg[:sizeOfMagicByte], magicHeader) { + return nil, nil, errors.New("invalid magic header") + } + + return msg[sizeOfMagicByte:headerSizeAuth], msg[headerSizeAuth:], nil +} + +// MarshalAuthResponse creates a response message to the auth. +// In case of success connection the server response with a AuthResponse message. This message contains the server's +// instance URL. This URL will be used by choose the common Relay server in case if the peers are in different Relay +// servers. +func MarshalAuthResponse(address string) ([]byte, error) { + ab := []byte(address) + msg := make([]byte, SizeOfProtoHeader, SizeOfProtoHeader+headerSizeAuthResp+len(ab)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeAuthResponse) + + msg = append(msg, ab...) + + if len(msg) > MaxHandshakeRespSize { + return nil, fmt.Errorf("invalid message length: %d", len(msg)) + } + + return msg, nil +} + +// UnmarshalAuthResponse it is a confirmation message to auth success +func UnmarshalAuthResponse(msg []byte) (string, error) { + if len(msg) < headerSizeAuthResp+1 { + return "", ErrInvalidMessageLength + } + return string(msg), nil +} + +// MarshalCloseMsg creates a close message. +// The close message is used to close the connection gracefully between the client and the server. The server and the +// client can send this message. After receiving this message, the server or client will close the connection. +func MarshalCloseMsg() []byte { + msg := make([]byte, SizeOfProtoHeader) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeClose) + + return msg +} + +// MarshalTransportMsg creates a transport message. +// The transport message is used to exchange data between peers. The message contains the data to be exchanged and the +// destination peer hashed ID. +func MarshalTransportMsg(peerID []byte, payload []byte) ([]byte, error) { + if len(peerID) != IDSize { + return nil, fmt.Errorf("invalid peerID length: %d", len(peerID)) + } + + msg := make([]byte, SizeOfProtoHeader+headerSizeTransport, SizeOfProtoHeader+headerSizeTransport+len(payload)) + + msg[0] = byte(CurrentProtocolVersion) + msg[1] = byte(MsgTypeTransport) + + copy(msg[SizeOfProtoHeader:], peerID) + + msg = append(msg, payload...) + + return msg, nil +} + +// UnmarshalTransportMsg extracts the peerID and the payload from the transport message. +func UnmarshalTransportMsg(buf []byte) ([]byte, []byte, error) { + if len(buf) < headerSizeTransport { + return nil, nil, ErrInvalidMessageLength + } + + return buf[:headerSizeTransport], buf[headerSizeTransport:], nil +} + +// UnmarshalTransportID extracts the peerID from the transport message. +func UnmarshalTransportID(buf []byte) ([]byte, error) { + if len(buf) < headerSizeTransport { + return nil, ErrInvalidMessageLength + } + return buf[:headerSizeTransport], nil +} + +// UpdateTransportMsg updates the peerID in the transport message. +// With this function the server can reuse the given byte slice to update the peerID in the transport message. So do +// need to allocate a new byte slice. +func UpdateTransportMsg(msg []byte, peerID []byte) error { + if len(msg) < len(peerID) { + return ErrInvalidMessageLength + } + copy(msg, peerID) + return nil +} + +// MarshalHealthcheck creates a health check message. +// Health check message is sent by the server periodically. The client will respond with a health check response +// message. If the client does not respond to the health check message, the server will close the connection. +func MarshalHealthcheck() []byte { + return healthCheckMsg +} diff --git a/relay/messages/message_test.go b/relay/messages/message_test.go new file mode 100644 index 000000000..6e917da71 --- /dev/null +++ b/relay/messages/message_test.go @@ -0,0 +1,59 @@ +package messages + +import ( + "testing" +) + +func TestMarshalHelloMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + bHello, err := MarshalHelloMsg(peerID, nil) + if err != nil { + t.Fatalf("error: %v", err) + } + + receivedPeerID, _, err := UnmarshalHelloMsg(bHello[SizeOfProtoHeader:]) + if err != nil { + t.Fatalf("error: %v", err) + } + if string(receivedPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, receivedPeerID) + } +} + +func TestMarshalAuthMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + bHello, err := MarshalAuthMsg(peerID, []byte{}) + if err != nil { + t.Fatalf("error: %v", err) + } + + receivedPeerID, _, err := UnmarshalAuthMsg(bHello[SizeOfProtoHeader:]) + if err != nil { + t.Fatalf("error: %v", err) + } + if string(receivedPeerID) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, receivedPeerID) + } +} + +func TestMarshalTransportMsg(t *testing.T) { + peerID := []byte("abdFAaBcawquEiCMzAabYosuUaGLtSNhKxz+") + payload := []byte("payload") + msg, err := MarshalTransportMsg(peerID, payload) + if err != nil { + t.Fatalf("error: %v", err) + } + + id, respPayload, err := UnmarshalTransportMsg(msg[SizeOfProtoHeader:]) + if err != nil { + t.Fatalf("error: %v", err) + } + + if string(id) != string(peerID) { + t.Errorf("expected %s, got %s", peerID, id) + } + + if string(respPayload) != string(payload) { + t.Errorf("expected %s, got %s", payload, respPayload) + } +} diff --git a/relay/metrics/realy.go b/relay/metrics/realy.go new file mode 100644 index 000000000..13799713a --- /dev/null +++ b/relay/metrics/realy.go @@ -0,0 +1,136 @@ +package metrics + +import ( + "context" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" +) + +const ( + idleTimeout = 30 * time.Second +) + +type Metrics struct { + metric.Meter + + TransferBytesSent metric.Int64Counter + TransferBytesRecv metric.Int64Counter + + peers metric.Int64UpDownCounter + peerActivityChan chan string + peerLastActive map[string]time.Time + mutexActivity sync.Mutex + ctx context.Context +} + +func NewMetrics(ctx context.Context, meter metric.Meter) (*Metrics, error) { + bytesSent, err := meter.Int64Counter("relay_transfer_sent_bytes_total") + if err != nil { + return nil, err + } + + bytesRecv, err := meter.Int64Counter("relay_transfer_received_bytes_total") + if err != nil { + return nil, err + } + + peers, err := meter.Int64UpDownCounter("relay_peers") + if err != nil { + return nil, err + } + + peersActive, err := meter.Int64ObservableGauge("relay_peers_active") + if err != nil { + return nil, err + } + + peersIdle, err := meter.Int64ObservableGauge("relay_peers_idle") + if err != nil { + return nil, err + } + + m := &Metrics{ + Meter: meter, + TransferBytesSent: bytesSent, + TransferBytesRecv: bytesRecv, + peers: peers, + + ctx: ctx, + peerActivityChan: make(chan string, 10), + peerLastActive: make(map[string]time.Time), + } + + _, err = meter.RegisterCallback( + func(ctx context.Context, o metric.Observer) error { + active, idle := m.calculateActiveIdleConnections() + o.ObserveInt64(peersActive, active) + o.ObserveInt64(peersIdle, idle) + return nil + }, + peersActive, peersIdle, + ) + if err != nil { + return nil, err + } + + go m.readPeerActivity() + return m, nil +} + +// PeerConnected increments the number of connected peers and increments number of idle connections +func (m *Metrics) PeerConnected(id string) { + m.peers.Add(m.ctx, 1) + m.mutexActivity.Lock() + defer m.mutexActivity.Unlock() + + m.peerLastActive[id] = time.Time{} +} + +// PeerDisconnected decrements the number of connected peers and decrements number of idle or active connections +func (m *Metrics) PeerDisconnected(id string) { + m.peers.Add(m.ctx, -1) + m.mutexActivity.Lock() + defer m.mutexActivity.Unlock() + + delete(m.peerLastActive, id) +} + +// PeerActivity increases the active connections +func (m *Metrics) PeerActivity(peerID string) { + select { + case m.peerActivityChan <- peerID: + default: + log.Tracef("peer activity channel is full, dropping activity metrics for peer %s", peerID) + } +} + +func (m *Metrics) calculateActiveIdleConnections() (int64, int64) { + active, idle := int64(0), int64(0) + m.mutexActivity.Lock() + defer m.mutexActivity.Unlock() + + for _, lastActive := range m.peerLastActive { + if time.Since(lastActive) > idleTimeout { + idle++ + } else { + active++ + } + } + return active, idle +} + +func (m *Metrics) readPeerActivity() { + for { + select { + case peerID := <-m.peerActivityChan: + m.mutexActivity.Lock() + m.peerLastActive[peerID] = time.Now() + m.mutexActivity.Unlock() + case <-m.ctx.Done(): + return + } + } +} diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go new file mode 100644 index 000000000..535c8bcd9 --- /dev/null +++ b/relay/server/listener/listener.go @@ -0,0 +1,11 @@ +package listener + +import ( + "context" + "net" +) + +type Listener interface { + Listen(func(conn net.Conn)) error + Shutdown(ctx context.Context) error +} diff --git a/relay/server/listener/ws/conn.go b/relay/server/listener/ws/conn.go new file mode 100644 index 000000000..c248963b9 --- /dev/null +++ b/relay/server/listener/ws/conn.go @@ -0,0 +1,114 @@ +package ws + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + "nhooyr.io/websocket" +) + +const ( + writeTimeout = 10 * time.Second +) + +type Conn struct { + *websocket.Conn + lAddr *net.TCPAddr + rAddr *net.TCPAddr + + closed bool + closedMu sync.Mutex + ctx context.Context +} + +func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { + return &Conn{ + Conn: wsConn, + lAddr: lAddr, + rAddr: rAddr, + ctx: context.Background(), + } +} + +func (c *Conn) Read(b []byte) (n int, err error) { + t, r, err := c.Reader(c.ctx) + if err != nil { + return 0, c.ioErrHandling(err) + } + + if t != websocket.MessageBinary { + log.Errorf("unexpected message type: %d", t) + return 0, fmt.Errorf("unexpected message type") + } + + n, err = r.Read(b) + if err != nil { + return 0, c.ioErrHandling(err) + } + return n, err +} + +// Write writes a binary message with the given payload. +// It does not block until fill the internal buffer. +// If the buffer filled up, wait until the buffer is drained or timeout. +func (c *Conn) Write(b []byte) (int, error) { + ctx, ctxCancel := context.WithTimeout(c.ctx, writeTimeout) + defer ctxCancel() + + err := c.Conn.Write(ctx, websocket.MessageBinary, b) + return len(b), err +} + +func (c *Conn) LocalAddr() net.Addr { + return c.lAddr +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.rAddr +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + return fmt.Errorf("SetReadDeadline is not implemented") +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + return fmt.Errorf("SetWriteDeadline is not implemented") +} + +func (c *Conn) SetDeadline(t time.Time) error { + return fmt.Errorf("SetDeadline is not implemented") +} + +func (c *Conn) Close() error { + c.closedMu.Lock() + c.closed = true + c.closedMu.Unlock() + return c.Conn.CloseNow() +} + +func (c *Conn) isClosed() bool { + c.closedMu.Lock() + defer c.closedMu.Unlock() + return c.closed +} + +func (c *Conn) ioErrHandling(err error) error { + if c.isClosed() { + return io.EOF + } + + var wErr *websocket.CloseError + if !errors.As(err, &wErr) { + return err + } + if wErr.Code == websocket.StatusNormalClosure { + return io.EOF + } + return err +} diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go new file mode 100644 index 000000000..10bfbe44d --- /dev/null +++ b/relay/server/listener/ws/listener.go @@ -0,0 +1,92 @@ +package ws + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + + log "github.com/sirupsen/logrus" + "nhooyr.io/websocket" +) + +// URLPath is the path for the websocket connection. +const URLPath = "/relay" + +type Listener struct { + // Address is the address to listen on. + Address string + // TLSConfig is the TLS configuration for the server. + TLSConfig *tls.Config + + server *http.Server + acceptFn func(conn net.Conn) +} + +func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { + l.acceptFn = acceptFn + mux := http.NewServeMux() + mux.HandleFunc(URLPath, l.onAccept) + + l.server = &http.Server{ + Addr: l.Address, + Handler: mux, + TLSConfig: l.TLSConfig, + } + + log.Infof("WS server listening address: %s", l.Address) + var err error + if l.TLSConfig != nil { + err = l.server.ListenAndServeTLS("", "") + } else { + err = l.server.ListenAndServe() + } + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err +} + +func (l *Listener) Shutdown(ctx context.Context) error { + if l.server == nil { + return nil + } + + log.Infof("stop WS listener") + if err := l.server.Shutdown(ctx); err != nil { + return fmt.Errorf("server shutdown failed: %v", err) + } + log.Infof("WS listener stopped") + return nil +} + +func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { + wsConn, err := websocket.Accept(w, r, nil) + if err != nil { + log.Errorf("failed to accept ws connection from %s: %s", r.RemoteAddr, err) + return + } + + rAddr, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) + if err != nil { + err = wsConn.Close(websocket.StatusInternalError, "internal error") + if err != nil { + log.Errorf("failed to close ws connection: %s", err) + } + return + } + + lAddr, err := net.ResolveTCPAddr("tcp", l.server.Addr) + if err != nil { + err = wsConn.Close(websocket.StatusInternalError, "internal error") + if err != nil { + log.Errorf("failed to close ws connection: %s", err) + } + return + } + + conn := NewConn(wsConn, lAddr, rAddr) + l.acceptFn(conn) +} diff --git a/relay/server/peer.go b/relay/server/peer.go new file mode 100644 index 000000000..a9c542f84 --- /dev/null +++ b/relay/server/peer.go @@ -0,0 +1,212 @@ +package server + +import ( + "context" + "io" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/healthcheck" + "github.com/netbirdio/netbird/relay/messages" + "github.com/netbirdio/netbird/relay/metrics" +) + +const ( + bufferSize = 8820 +) + +// Peer represents a peer connection +type Peer struct { + metrics *metrics.Metrics + log *log.Entry + idS string + idB []byte + conn net.Conn + connMu sync.RWMutex + store *Store +} + +// NewPeer creates a new Peer instance and prepare custom logging +func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *Peer { + stringID := messages.HashIDToString(id) + return &Peer{ + metrics: metrics, + log: log.WithField("peer_id", stringID), + idS: stringID, + idB: id, + conn: conn, + store: store, + } +} + +// Work reads data from the connection +// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle +// the message accordingly. +func (p *Peer) Work() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hc := healthcheck.NewSender(p.log) + go hc.StartHealthCheck(ctx) + go p.handleHealthcheckEvents(ctx, hc) + + buf := make([]byte, bufferSize) + for { + n, err := p.conn.Read(buf) + if err != nil { + if err != io.EOF { + p.log.Errorf("failed to read message: %s", err) + } + return + } + + if n == 0 { + p.log.Errorf("received empty message") + return + } + + msg := buf[:n] + + _, err = messages.ValidateVersion(msg) + if err != nil { + p.log.Warnf("failed to validate protocol version: %s", err) + return + } + + msgType, err := messages.DetermineClientMessageType(msg[messages.SizeOfVersionByte:]) + if err != nil { + p.log.Errorf("failed to determine message type: %s", err) + return + } + + p.handleMsgType(ctx, msgType, hc, n, msg) + } +} + +func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *healthcheck.Sender, n int, msg []byte) { + switch msgType { + case messages.MsgTypeHealthCheck: + hc.OnHCResponse() + case messages.MsgTypeTransport: + p.metrics.TransferBytesRecv.Add(ctx, int64(n)) + p.metrics.PeerActivity(p.String()) + p.handleTransportMsg(msg) + case messages.MsgTypeClose: + p.log.Infof("peer exited gracefully") + if err := p.conn.Close(); err != nil { + log.Errorf("failed to close connection to peer: %s", err) + } + default: + p.log.Warnf("received unexpected message type: %s", msgType) + } +} + +// Write writes data to the connection +func (p *Peer) Write(b []byte) (int, error) { + p.connMu.RLock() + defer p.connMu.RUnlock() + return p.conn.Write(b) +} + +// CloseGracefully closes the connection with the peer gracefully. Send a close message to the client and close the +// connection. +func (p *Peer) CloseGracefully(ctx context.Context) { + p.connMu.Lock() + defer p.connMu.Unlock() + err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()) + if err != nil { + p.log.Errorf("failed to send close message to peer: %s", p.String()) + } + + err = p.conn.Close() + if err != nil { + p.log.Errorf("failed to close connection to peer: %s", err) + } +} + +func (p *Peer) Close() { + p.connMu.Lock() + defer p.connMu.Unlock() + + if err := p.conn.Close(); err != nil { + p.log.Errorf("failed to close connection to peer: %s", err) + } +} + +// String returns the peer ID +func (p *Peer) String() string { + return p.idS +} + +func (p *Peer) writeWithTimeout(ctx context.Context, buf []byte) error { + ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + defer cancel() + + writeDone := make(chan struct{}) + var err error + go func() { + _, err = p.conn.Write(buf) + close(writeDone) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-writeDone: + return err + } +} + +func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Sender) { + for { + select { + case <-hc.HealthCheck: + _, err := p.Write(messages.MarshalHealthcheck()) + if err != nil { + p.log.Errorf("failed to send healthcheck message: %s", err) + return + } + case <-hc.Timeout: + p.log.Errorf("peer healthcheck timeout") + err := p.conn.Close() + if err != nil { + p.log.Errorf("failed to close connection to peer: %s", err) + } + p.log.Info("peer connection closed due healthcheck timeout") + return + case <-ctx.Done(): + return + } + } +} + +func (p *Peer) handleTransportMsg(msg []byte) { + peerID, err := messages.UnmarshalTransportID(msg[messages.SizeOfProtoHeader:]) + if err != nil { + p.log.Errorf("failed to unmarshal transport message: %s", err) + return + } + + stringPeerID := messages.HashIDToString(peerID) + dp, ok := p.store.Peer(stringPeerID) + if !ok { + p.log.Debugf("peer not found: %s", stringPeerID) + return + } + + err = messages.UpdateTransportMsg(msg[messages.SizeOfProtoHeader:], p.idB) + if err != nil { + p.log.Errorf("failed to update transport message: %s", err) + return + } + + n, err := dp.Write(msg) + if err != nil { + p.log.Errorf("failed to write transport message to: %s", dp.String()) + return + } + p.metrics.TransferBytesSent.Add(context.Background(), int64(n)) +} diff --git a/relay/server/relay.go b/relay/server/relay.go new file mode 100644 index 000000000..76c01a697 --- /dev/null +++ b/relay/server/relay.go @@ -0,0 +1,249 @@ +package server + +import ( + "context" + "fmt" + "net" + "net/url" + "strings" + "sync" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/messages" + //nolint:staticcheck + "github.com/netbirdio/netbird/relay/messages/address" + //nolint:staticcheck + authmsg "github.com/netbirdio/netbird/relay/messages/auth" + "github.com/netbirdio/netbird/relay/metrics" +) + +// Relay represents the relay server +type Relay struct { + metrics *metrics.Metrics + metricsCancel context.CancelFunc + validator auth.Validator + + store *Store + instanceURL string + + closed bool + closeMu sync.RWMutex +} + +// NewRelay creates a new Relay instance +// +// Parameters: +// meter: An instance of metric.Meter from the go.opentelemetry.io/otel/metric package. It is used to create and manage +// metrics for the relay server. +// exposedAddress: A string representing the address that the relay server is exposed on. The client will use this +// address as the relay server's instance URL. +// tlsSupport: A boolean indicating whether the relay server supports TLS (Transport Layer Security) or not. The +// instance URL depends on this value. +// validator: An instance of auth.Validator from the auth package. It is used to validate the authentication of the +// peers. +// +// Returns: +// A pointer to a Relay instance and an error. If the Relay instance is successfully created, the error is nil. +// Otherwise, the error contains the details of what went wrong. +func NewRelay(meter metric.Meter, exposedAddress string, tlsSupport bool, validator auth.Validator) (*Relay, error) { + ctx, metricsCancel := context.WithCancel(context.Background()) + m, err := metrics.NewMetrics(ctx, meter) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("creating app metrics: %v", err) + } + + r := &Relay{ + metrics: m, + metricsCancel: metricsCancel, + validator: validator, + store: NewStore(), + } + + r.instanceURL, err = getInstanceURL(exposedAddress, tlsSupport) + if err != nil { + metricsCancel() + return nil, fmt.Errorf("get instance URL: %v", err) + } + + return r, nil +} + +// getInstanceURL checks if user supplied a URL scheme otherwise adds to the +// provided address according to TLS definition and parses the address before returning it +func getInstanceURL(exposedAddress string, tlsSupported bool) (string, error) { + addr := exposedAddress + split := strings.Split(exposedAddress, "://") + switch { + case len(split) == 1 && tlsSupported: + addr = "rels://" + exposedAddress + case len(split) == 1 && !tlsSupported: + addr = "rel://" + exposedAddress + case len(split) > 2: + return "", fmt.Errorf("invalid exposed address: %s", exposedAddress) + } + + parsedURL, err := url.ParseRequestURI(addr) + if err != nil { + return "", fmt.Errorf("invalid exposed address: %v", err) + } + + if parsedURL.Scheme != "rel" && parsedURL.Scheme != "rels" { + return "", fmt.Errorf("invalid scheme: %s", parsedURL.Scheme) + } + + return parsedURL.String(), nil +} + +// Accept start to handle a new peer connection +func (r *Relay) Accept(conn net.Conn) { + r.closeMu.RLock() + defer r.closeMu.RUnlock() + if r.closed { + return + } + + peerID, err := r.handshake(conn) + if err != nil { + log.Errorf("failed to handshake: %s", err) + cErr := conn.Close() + if cErr != nil { + log.Errorf("failed to close connection, %s: %s", conn.RemoteAddr(), cErr) + } + return + } + + peer := NewPeer(r.metrics, peerID, conn, r.store) + peer.log.Infof("peer connected from: %s", conn.RemoteAddr()) + r.store.AddPeer(peer) + r.metrics.PeerConnected(peer.String()) + go func() { + peer.Work() + r.store.DeletePeer(peer) + peer.log.Debugf("relay connection closed") + r.metrics.PeerDisconnected(peer.String()) + }() +} + +// Shutdown closes the relay server +// It closes the connection with all peers in gracefully and stops accepting new connections. +func (r *Relay) Shutdown(ctx context.Context) { + log.Infof("close connection with all peers") + r.closeMu.Lock() + wg := sync.WaitGroup{} + peers := r.store.Peers() + for _, peer := range peers { + wg.Add(1) + go func(p *Peer) { + p.CloseGracefully(ctx) + wg.Done() + }(peer) + } + wg.Wait() + r.metricsCancel() + r.closeMu.Unlock() +} + +// InstanceURL returns the instance URL of the relay server +func (r *Relay) InstanceURL() string { + return r.instanceURL +} + +func (r *Relay) handshake(conn net.Conn) ([]byte, error) { + buf := make([]byte, messages.MaxHandshakeSize) + n, err := conn.Read(buf) + if err != nil { + return nil, fmt.Errorf("read from %s: %w", conn.RemoteAddr(), err) + } + + _, err = messages.ValidateVersion(buf[:n]) + if err != nil { + return nil, fmt.Errorf("validate version from %s: %w", conn.RemoteAddr(), err) + } + + msgType, err := messages.DetermineClientMessageType(buf[messages.SizeOfVersionByte:n]) + if err != nil { + return nil, fmt.Errorf("determine message type from %s: %w", conn.RemoteAddr(), err) + } + + var ( + responseMsg []byte + peerID []byte + ) + switch msgType { + //nolint:staticcheck + case messages.MsgTypeHello: + peerID, responseMsg, err = r.handleHelloMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + case messages.MsgTypeAuth: + peerID, responseMsg, err = r.handleAuthMsg(buf[messages.SizeOfProtoHeader:n], conn.RemoteAddr()) + default: + return nil, fmt.Errorf("invalid message type %d from %s", msgType, conn.RemoteAddr()) + } + if err != nil { + return nil, err + } + + _, err = conn.Write(responseMsg) + if err != nil { + return nil, fmt.Errorf("write to %s (%s): %w", peerID, conn.RemoteAddr(), err) + } + + return peerID, nil +} + +func (r *Relay) handleHelloMsg(buf []byte, remoteAddr net.Addr) ([]byte, []byte, error) { + //nolint:staticcheck + rawPeerID, authData, err := messages.UnmarshalHelloMsg(buf) + if err != nil { + return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + log.Warnf("peer %s (%s) is using deprecated initial message type", peerID, remoteAddr) + + authMsg, err := authmsg.UnmarshalMsg(authData) + if err != nil { + return nil, nil, fmt.Errorf("unmarshal auth message: %w", err) + } + + //nolint:staticcheck + if err := r.validator.ValidateHelloMsgType(authMsg.AdditionalData); err != nil { + return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, remoteAddr, err) + } + + addr := &address.Address{URL: r.instanceURL} + addrData, err := addr.Marshal() + if err != nil { + return nil, nil, fmt.Errorf("marshal addressc to %s (%s): %w", peerID, remoteAddr, err) + } + + //nolint:staticcheck + responseMsg, err := messages.MarshalHelloResponse(addrData) + if err != nil { + return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, remoteAddr, err) + } + return rawPeerID, responseMsg, nil +} + +func (r *Relay) handleAuthMsg(buf []byte, addr net.Addr) ([]byte, []byte, error) { + rawPeerID, authPayload, err := messages.UnmarshalAuthMsg(buf) + if err != nil { + return nil, nil, fmt.Errorf("unmarshal hello message: %w", err) + } + + peerID := messages.HashIDToString(rawPeerID) + + if err := r.validator.Validate(authPayload); err != nil { + return nil, nil, fmt.Errorf("validate %s (%s): %w", peerID, addr, err) + } + + responseMsg, err := messages.MarshalAuthResponse(r.instanceURL) + if err != nil { + return nil, nil, fmt.Errorf("marshal hello response to %s (%s): %w", peerID, addr, err) + } + + return rawPeerID, responseMsg, nil +} diff --git a/relay/server/relay_test.go b/relay/server/relay_test.go new file mode 100644 index 000000000..062039ab9 --- /dev/null +++ b/relay/server/relay_test.go @@ -0,0 +1,36 @@ +package server + +import "testing" + +func TestGetInstanceURL(t *testing.T) { + tests := []struct { + name string + exposedAddress string + tlsSupported bool + expectedURL string + expectError bool + }{ + {"Valid address with TLS", "example.com", true, "rels://example.com", false}, + {"Valid address without TLS", "example.com", false, "rel://example.com", false}, + {"Valid address with scheme", "rel://example.com", false, "rel://example.com", false}, + {"Valid address with non TLS scheme and TLS true", "rel://example.com", true, "rel://example.com", false}, + {"Valid address with TLS scheme", "rels://example.com", true, "rels://example.com", false}, + {"Valid address with TLS scheme and TLS false", "rels://example.com", false, "rels://example.com", false}, + {"Valid address with TLS scheme and custom port", "rels://example.com:9300", true, "rels://example.com:9300", false}, + {"Invalid address with multiple schemes", "rel://rels://example.com", false, "", true}, + {"Invalid address with unsupported scheme", "http://example.com", false, "", true}, + {"Invalid address format", "://example.com", false, "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + url, err := getInstanceURL(tt.exposedAddress, tt.tlsSupported) + if (err != nil) != tt.expectError { + t.Errorf("expected error: %v, got: %v", tt.expectError, err) + } + if url != tt.expectedURL { + t.Errorf("expected URL: %s, got: %s", tt.expectedURL, url) + } + }) + } +} diff --git a/relay/server/server.go b/relay/server/server.go new file mode 100644 index 000000000..0036e2390 --- /dev/null +++ b/relay/server/server.go @@ -0,0 +1,76 @@ +package server + +import ( + "context" + "crypto/tls" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/relay/auth" + "github.com/netbirdio/netbird/relay/server/listener" + "github.com/netbirdio/netbird/relay/server/listener/ws" +) + +// ListenerConfig is the configuration for the listener. +// Address: the address to bind the listener to. It could be an address behind a reverse proxy. +// TLSConfig: the TLS configuration for the listener. +type ListenerConfig struct { + Address string + TLSConfig *tls.Config +} + +// Server is the main entry point for the relay server. +// It is the gate between the WebSocket listener and the Relay server logic. +// In a new HTTP connection, the server will accept the connection and pass it to the Relay server via the Accept method. +type Server struct { + relay *Relay + wSListener listener.Listener +} + +// NewServer creates a new relay server instance. +// meter: the OpenTelemetry meter +// exposedAddress: this address will be used as the instance URL. It should be a domain:port format. +// tlsSupport: if true, the server will support TLS +// authValidator: the auth validator to use for the server +func NewServer(meter metric.Meter, exposedAddress string, tlsSupport bool, authValidator auth.Validator) (*Server, error) { + relay, err := NewRelay(meter, exposedAddress, tlsSupport, authValidator) + if err != nil { + return nil, err + } + return &Server{ + relay: relay, + }, nil +} + +// Listen starts the relay server. +func (r *Server) Listen(cfg ListenerConfig) error { + r.wSListener = &ws.Listener{ + Address: cfg.Address, + TLSConfig: cfg.TLSConfig, + } + + wslErr := r.wSListener.Listen(r.relay.Accept) + if wslErr != nil { + log.Errorf("failed to bind ws server: %s", wslErr) + } + + return wslErr +} + +// Shutdown stops the relay server. If there are active connections, they will be closed gracefully. In case of a context, +// the connections will be forcefully closed. +func (r *Server) Shutdown(ctx context.Context) (err error) { + // stop service new connections + if r.wSListener != nil { + err = r.wSListener.Shutdown(ctx) + } + + r.relay.Shutdown(ctx) + return +} + +// InstanceURL returns the instance URL of the relay server. +func (r *Server) InstanceURL() string { + return r.relay.instanceURL +} diff --git a/relay/server/store.go b/relay/server/store.go new file mode 100644 index 000000000..4288e62c5 --- /dev/null +++ b/relay/server/store.go @@ -0,0 +1,68 @@ +package server + +import ( + "sync" +) + +// Store is a thread-safe store of peers +// It is used to store the peers that are connected to the relay server +type Store struct { + peers map[string]*Peer // consider to use [32]byte as key. The Peer(id string) would be faster + peersLock sync.RWMutex +} + +// NewStore creates a new Store instance +func NewStore() *Store { + return &Store{ + peers: make(map[string]*Peer), + } +} + +// AddPeer adds a peer to the store +func (s *Store) AddPeer(peer *Peer) { + s.peersLock.Lock() + defer s.peersLock.Unlock() + odlPeer, ok := s.peers[peer.String()] + if ok { + odlPeer.Close() + } + + s.peers[peer.String()] = peer +} + +// DeletePeer deletes a peer from the store +func (s *Store) DeletePeer(peer *Peer) { + s.peersLock.Lock() + defer s.peersLock.Unlock() + + dp, ok := s.peers[peer.String()] + if !ok { + return + } + if dp != peer { + return + } + + delete(s.peers, peer.String()) +} + +// Peer returns a peer by its ID +func (s *Store) Peer(id string) (*Peer, bool) { + s.peersLock.RLock() + defer s.peersLock.RUnlock() + + p, ok := s.peers[id] + return p, ok +} + +// Peers returns all the peers in the store +func (s *Store) Peers() []*Peer { + s.peersLock.RLock() + defer s.peersLock.RUnlock() + + peers := make([]*Peer, 0, len(s.peers)) + for _, p := range s.peers { + peers = append(peers, p) + } + return peers +} diff --git a/relay/server/store_test.go b/relay/server/store_test.go new file mode 100644 index 000000000..41c7baa92 --- /dev/null +++ b/relay/server/store_test.go @@ -0,0 +1,85 @@ +package server + +import ( + "context" + "net" + "testing" + "time" + + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/relay/metrics" +) + +type mockConn struct { +} + +func (m mockConn) Read(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Write(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) LocalAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) RemoteAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func TestStore_DeletePeer(t *testing.T) { + s := NewStore() + + m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) + + p := NewPeer(m, []byte("peer_one"), nil, nil) + s.AddPeer(p) + s.DeletePeer(p) + if _, ok := s.Peer(p.String()); ok { + t.Errorf("peer was not deleted") + } +} + +func TestStore_DeleteDeprecatedPeer(t *testing.T) { + s := NewStore() + + m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) + + conn := &mockConn{} + p1 := NewPeer(m, []byte("peer_id"), conn, nil) + p2 := NewPeer(m, []byte("peer_id"), conn, nil) + + s.AddPeer(p1) + s.AddPeer(p2) + s.DeletePeer(p1) + + if _, ok := s.Peer(p2.String()); !ok { + t.Errorf("second peer was deleted") + } +} diff --git a/relay/test/benchmark_test.go b/relay/test/benchmark_test.go new file mode 100644 index 000000000..ec2aa488c --- /dev/null +++ b/relay/test/benchmark_test.go @@ -0,0 +1,386 @@ +package test + +import ( + "context" + "crypto/rand" + "fmt" + "net" + "os" + "sync" + "testing" + "time" + + "github.com/pion/logging" + "github.com/pion/turn/v3" + "go.opentelemetry.io/otel" + + "github.com/netbirdio/netbird/relay/auth/allow" + "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/client" + "github.com/netbirdio/netbird/relay/server" + "github.com/netbirdio/netbird/util" +) + +var ( + av = &allow.Auth{} + hmacTokenStore = &hmac.TokenStore{} + pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} + dataSize = 1024 * 1024 * 10 +) + +func TestMain(m *testing.M) { + _ = util.InitLog("error", "console") + code := m.Run() + os.Exit(code) +} + +func TestRelayDataTransfer(t *testing.T) { + t.SkipNow() // skip this test on CI because it is a benchmark test + testData, err := seedRandomData(dataSize) + if err != nil { + t.Fatalf("failed to seed random data: %s", err) + } + + for _, peerPairs := range pairs { + t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) { + transfer(t, testData, peerPairs) + }) + } +} + +// TestTurnDataTransfer run turn server: +// docker run --rm --name coturn -d --network=host coturn/coturn --user test:test +func TestTurnDataTransfer(t *testing.T) { + t.SkipNow() // skip this test on CI because it is a benchmark test + testData, err := seedRandomData(dataSize) + if err != nil { + t.Fatalf("failed to seed random data: %s", err) + } + + for _, peerPairs := range pairs { + t.Run(fmt.Sprintf("peerPairs-%d", peerPairs), func(t *testing.T) { + runTurnTest(t, testData, peerPairs) + }) + } +} + +func transfer(t *testing.T, testData []byte, peerPairs int) { + t.Helper() + ctx := context.Background() + port := 35000 + peerPairs + serverAddress := fmt.Sprintf("127.0.0.1:%d", port) + serverConnURL := fmt.Sprintf("rel://%s", serverAddress) + + srv, err := server.NewServer(otel.Meter(""), serverConnURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + listenCfg := server.ListenerConfig{Address: serverAddress} + err := srv.Listen(listenCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for server to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientsSender := make([]*client.Client, peerPairs) + for i := 0; i < cap(clientsSender); i++ { + c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + err := c.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + clientsSender[i] = c + } + + clientsReceiver := make([]*client.Client, peerPairs) + for i := 0; i < cap(clientsReceiver); i++ { + c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + clientsReceiver[i] = c + } + + connsSender := make([]net.Conn, 0, peerPairs) + connsReceiver := make([]net.Conn, 0, peerPairs) + for i := 0; i < len(clientsSender); i++ { + conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + connsSender = append(connsSender, conn) + + conn, err = clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + connsReceiver = append(connsReceiver, conn) + } + + var transferDuration []time.Duration + wg := sync.WaitGroup{} + var writeErr error + var readErr error + for i := 0; i < len(connsSender); i++ { + wg.Add(2) + start := time.Now() + go func(i int) { + defer wg.Done() + pieceSize := 1024 + testDataLen := len(testData) + + for j := 0; j < testDataLen; j += pieceSize { + end := j + pieceSize + if end > testDataLen { + end = testDataLen + } + _, writeErr = connsSender[i].Write(testData[j:end]) + if writeErr != nil { + return + } + } + + }(i) + + go func(i int, start time.Time) { + defer wg.Done() + buf := make([]byte, 8192) + rcv := 0 + var n int + for receivedSize := 0; receivedSize < len(testData); { + + n, readErr = connsReceiver[i].Read(buf) + if readErr != nil { + return + } + + receivedSize += n + rcv += n + } + transferDuration = append(transferDuration, time.Since(start)) + }(i, start) + } + + wg.Wait() + + if writeErr != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + if readErr != nil { + t.Fatalf("failed to read from channel: %s", err) + } + + // calculate the megabytes per second from the average transferDuration against the dataSize + var totalDuration time.Duration + for _, d := range transferDuration { + totalDuration += d + } + avgDuration := totalDuration / time.Duration(len(transferDuration)) + mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024 + t.Logf("average transfer duration: %s", avgDuration) + t.Logf("average transfer speed: %.2f MB/s", mbps) + + for i := 0; i < len(connsSender); i++ { + err := connsSender[i].Close() + if err != nil { + t.Errorf("failed to close connection: %s", err) + } + + err = connsReceiver[i].Close() + if err != nil { + t.Errorf("failed to close connection: %s", err) + } + } +} + +func runTurnTest(t *testing.T, testData []byte, maxPairs int) { + t.Helper() + var transferDuration []time.Duration + var wg sync.WaitGroup + + for i := 0; i < maxPairs; i++ { + wg.Add(1) + go func() { + defer wg.Done() + d := runTurnDataTransfer(t, testData) + transferDuration = append(transferDuration, d) + }() + + } + wg.Wait() + + var totalDuration time.Duration + for _, d := range transferDuration { + totalDuration += d + } + avgDuration := totalDuration / time.Duration(len(transferDuration)) + mbps := float64(len(testData)) / avgDuration.Seconds() / 1024 / 1024 + t.Logf("average transfer duration: %s", avgDuration) + t.Logf("average transfer speed: %.2f MB/s", mbps) +} + +func runTurnDataTransfer(t *testing.T, testData []byte) time.Duration { + t.Helper() + testDataLen := len(testData) + relayAddress := "192.168.0.10:3478" + conn, err := net.Dial("tcp", relayAddress) + if err != nil { + t.Fatal(err) + } + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) + + turnClient, err := getTurnClient(t, relayAddress, conn) + if err != nil { + t.Fatal(err) + } + defer turnClient.Close() + + relayConn, err := turnClient.Allocate() + if err != nil { + t.Fatal(err) + } + defer func(relayConn net.PacketConn) { + _ = relayConn.Close() + }(relayConn) + + receiverConn, err := net.Dial("udp", relayConn.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + defer func(receiverConn net.Conn) { + _ = receiverConn.Close() + }(receiverConn) + + var ( + tb int + start time.Time + timerInit bool + readDone = make(chan struct{}) + ack = make([]byte, 1) + ) + go func() { + defer func() { + readDone <- struct{}{} + }() + buff := make([]byte, 8192) + for { + n, e := receiverConn.Read(buff) + if e != nil { + return + } + if !timerInit { + start = time.Now() + timerInit = true + } + tb += n + _, _ = receiverConn.Write(ack) + + if tb >= testDataLen { + return + } + } + }() + + pieceSize := 1024 + ackBuff := make([]byte, 1) + pipelineSize := 10 + for j := 0; j < testDataLen; j += pieceSize { + end := j + pieceSize + if end > testDataLen { + end = testDataLen + } + _, err := relayConn.WriteTo(testData[j:end], receiverConn.LocalAddr()) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + if pipelineSize == 0 { + _, _, _ = relayConn.ReadFrom(ackBuff) + } else { + pipelineSize-- + } + } + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + select { + case <-readDone: + if tb != testDataLen { + t.Fatalf("failed to read all data: %d/%d", tb, testDataLen) + } + case <-ctx.Done(): + t.Fatal("timeout") + } + return time.Since(start) +} + +func getTurnClient(t *testing.T, address string, conn net.Conn) (*turn.Client, error) { + t.Helper() + // Dial TURN Server + addrStr := fmt.Sprintf("%s:%d", address, 443) + + fac := logging.NewDefaultLoggerFactory() + //fac.DefaultLogLevel = logging.LogLevelTrace + + // Start a new TURN Client and wrap our net.Conn in a STUNConn + // This allows us to simulate datagram based communication over a net.Conn + cfg := &turn.ClientConfig{ + TURNServerAddr: address, + Conn: turn.NewSTUNConn(conn), + Username: "test", + Password: "test", + LoggerFactory: fac, + } + + client, err := turn.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err) + } + + // Start listening on the conn provided. + err = client.Listen() + if err != nil { + client.Close() + return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err) + } + + return client, nil +} + +func seedRandomData(size int) ([]byte, error) { + token := make([]byte, size) + _, err := rand.Read(token) + if err != nil { + return nil, err + } + return token, nil +} + +func waitForServerToStart(errChan chan error) error { + select { + case err := <-errChan: + if err != nil { + return err + } + case <-time.After(300 * time.Millisecond): + return nil + } + return nil +} diff --git a/relay/testec2/main.go b/relay/testec2/main.go new file mode 100644 index 000000000..0c8099a5e --- /dev/null +++ b/relay/testec2/main.go @@ -0,0 +1,258 @@ +//go:build linux || darwin + +package main + +import ( + "crypto/rand" + "flag" + "fmt" + "net" + "os" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +const ( + errMsgFailedReadTCP = "failed to read from tcp: %s" +) + +var ( + dataSize = 1024 * 1024 * 50 // 50MB + pairs = []int{1, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100} + signalListenAddress = ":8081" + + relaySrvAddress string + turnSrvAddress string + signalURL string + udpListener string // used for TURN test +) + +type testResult struct { + numOfPairs int + duration time.Duration + speed float64 +} + +func (tr testResult) Speed() string { + speed := tr.speed + var unit string + + switch { + case speed < 1024: + unit = "B/s" + case speed < 1048576: + speed /= 1024 + unit = "KB/s" + case speed < 1073741824: + speed /= 1048576 + unit = "MB/s" + default: + speed /= 1073741824 + unit = "GB/s" + } + + return fmt.Sprintf("%.2f %s", speed, unit) +} + +func seedRandomData(size int) ([]byte, error) { + token := make([]byte, size) + _, err := rand.Read(token) + if err != nil { + return nil, err + } + return token, nil +} + +func avg(transferDuration []time.Duration) (time.Duration, float64) { + var totalDuration time.Duration + for _, d := range transferDuration { + totalDuration += d + } + avgDuration := totalDuration / time.Duration(len(transferDuration)) + bps := float64(dataSize) / avgDuration.Seconds() + return avgDuration, bps +} + +func RelayReceiverMain() []testResult { + testResults := make([]testResult, 0, len(pairs)) + for _, p := range pairs { + tr := testResult{numOfPairs: p} + td := relayReceive(relaySrvAddress, p) + tr.duration, tr.speed = avg(td) + + testResults = append(testResults, tr) + } + + return testResults +} + +func RelaySenderMain() { + log.Infof("starting sender") + log.Infof("starting seed phase") + + testData, err := seedRandomData(dataSize) + if err != nil { + log.Fatalf("failed to seed random data: %s", err) + } + + log.Infof("data size: %d", len(testData)) + + for n, p := range pairs { + log.Infof("running test with %d pairs", p) + relayTransfer(relaySrvAddress, testData, p) + + // grant time to prepare new receivers + if n < len(pairs)-1 { + time.Sleep(3 * time.Second) + } + } + +} + +// TRUNSenderMain is the sender +// - allocate turn clients +// - send relayed addresses to signal server in batch +// - wait for signal server to send back addresses in a map +// - send test data to each address in parallel +func TRUNSenderMain() { + log.Infof("starting TURN sender test") + + log.Infof("starting seed random data: %d", dataSize) + testData, err := seedRandomData(dataSize) + if err != nil { + log.Fatalf("failed to seed random data: %s", err) + } + + ss := SignalClient{signalURL} + + for _, p := range pairs { + log.Infof("running test with %d pairs", p) + turnSender := &TurnSender{} + + createTurnConns(p, turnSender) + + log.Infof("send addresses via signal server: %d", len(turnSender.addresses)) + clientAddresses, err := ss.SendAddress(turnSender.addresses) + if err != nil { + log.Fatalf("failed to send address: %s", err) + } + log.Infof("received addresses: %v", clientAddresses.Address) + + createSenderDevices(turnSender, clientAddresses) + + log.Infof("waiting for tcpListeners to be ready") + time.Sleep(2 * time.Second) + + tcpConns := make([]net.Conn, 0, len(turnSender.devices)) + for i := range turnSender.devices { + addr := fmt.Sprintf("10.0.%d.2:9999", i) + log.Infof("dialing: %s", addr) + tcpConn, err := net.Dial("tcp", addr) + if err != nil { + log.Fatalf("failed to dial tcp: %s", err) + } + tcpConns = append(tcpConns, tcpConn) + } + + log.Infof("start test data transfer for %d pairs", p) + testDataLen := len(testData) + wg := sync.WaitGroup{} + wg.Add(len(tcpConns)) + for i, tcpConn := range tcpConns { + log.Infof("sending test data to device: %d", i) + go runTurnWriting(tcpConn, testData, testDataLen, &wg) + } + wg.Wait() + + for _, d := range turnSender.devices { + _ = d.Close() + } + + log.Infof("test finished with %d pairs", p) + } +} + +func TURNReaderMain() []testResult { + log.Infof("starting TURN receiver test") + si := NewSignalService() + go func() { + log.Infof("starting signal server") + err := si.Listen(signalListenAddress) + if err != nil { + log.Errorf("failed to listen: %s", err) + } + }() + + testResults := make([]testResult, 0, len(pairs)) + for range pairs { + addresses := <-si.AddressesChan + instanceNumber := len(addresses) + log.Infof("received addresses: %d", instanceNumber) + + turnReceiver := &TurnReceiver{} + err := createDevices(addresses, turnReceiver) + if err != nil { + log.Fatalf("%s", err) + } + + // send client addresses back via signal server + si.ClientAddressChan <- turnReceiver.clientAddresses + + durations := make(chan time.Duration, instanceNumber) + for _, device := range turnReceiver.devices { + go runTurnReading(device, durations) + } + + durationsList := make([]time.Duration, 0, instanceNumber) + for d := range durations { + durationsList = append(durationsList, d) + if len(durationsList) == instanceNumber { + close(durations) + } + } + + avgDuration, avgSpeed := avg(durationsList) + ts := testResult{ + numOfPairs: len(durationsList), + duration: avgDuration, + speed: avgSpeed, + } + testResults = append(testResults, ts) + + for _, d := range turnReceiver.devices { + _ = d.Close() + } + } + return testResults +} + +func main() { + var mode string + + _ = util.InitLog("debug", "console") + flag.StringVar(&mode, "mode", "sender", "sender or receiver mode") + flag.Parse() + + relaySrvAddress = os.Getenv("TEST_RELAY_SERVER") // rel://ip:port + turnSrvAddress = os.Getenv("TEST_TURN_SERVER") // ip:3478 + signalURL = os.Getenv("TEST_SIGNAL_URL") // http://receiver_ip:8081 + udpListener = os.Getenv("TEST_UDP_LISTENER") // IP:0 + + if mode == "receiver" { + relayResult := RelayReceiverMain() + turnResults := TURNReaderMain() + for i := 0; i < len(turnResults); i++ { + log.Infof("pairs: %d,\tRelay speed:\t%s,\trelay duration:\t%s", relayResult[i].numOfPairs, relayResult[i].Speed(), relayResult[i].duration) + log.Infof("pairs: %d,\tTURN speed:\t%s,\tturn duration:\t%s", turnResults[i].numOfPairs, turnResults[i].Speed(), turnResults[i].duration) + } + } else { + RelaySenderMain() + // grant time for receiver to start + time.Sleep(3 * time.Second) + TRUNSenderMain() + } +} diff --git a/relay/testec2/relay.go b/relay/testec2/relay.go new file mode 100644 index 000000000..93d084387 --- /dev/null +++ b/relay/testec2/relay.go @@ -0,0 +1,176 @@ +//go:build linux || darwin + +package main + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/relay/auth/hmac" + "github.com/netbirdio/netbird/relay/client" +) + +var ( + hmacTokenStore = &hmac.TokenStore{} +) + +func relayTransfer(serverConnURL string, testData []byte, peerPairs int) { + connsSender := prepareConnsSender(serverConnURL, peerPairs) + defer func() { + for i := 0; i < len(connsSender); i++ { + err := connsSender[i].Close() + if err != nil { + log.Errorf("failed to close connection: %s", err) + } + } + }() + + wg := sync.WaitGroup{} + wg.Add(len(connsSender)) + for _, conn := range connsSender { + go func(conn net.Conn) { + defer wg.Done() + runWriter(conn, testData) + }(conn) + } + wg.Wait() +} + +func runWriter(conn net.Conn, testData []byte) { + si := NewStartInidication(time.Now(), len(testData)) + _, err := conn.Write(si) + if err != nil { + log.Errorf("failed to write to channel: %s", err) + return + } + log.Infof("sent start indication") + + pieceSize := 1024 + testDataLen := len(testData) + + for j := 0; j < testDataLen; j += pieceSize { + end := j + pieceSize + if end > testDataLen { + end = testDataLen + } + _, writeErr := conn.Write(testData[j:end]) + if writeErr != nil { + log.Errorf("failed to write to channel: %s", writeErr) + return + } + } +} + +func prepareConnsSender(serverConnURL string, peerPairs int) []net.Conn { + ctx := context.Background() + clientsSender := make([]*client.Client, peerPairs) + for i := 0; i < cap(clientsSender); i++ { + c := client.NewClient(ctx, serverConnURL, hmacTokenStore, "sender-"+fmt.Sprint(i)) + if err := c.Connect(); err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + clientsSender[i] = c + } + + connsSender := make([]net.Conn, 0, peerPairs) + for i := 0; i < len(clientsSender); i++ { + conn, err := clientsSender[i].OpenConn("receiver-" + fmt.Sprint(i)) + if err != nil { + log.Fatalf("failed to bind channel: %s", err) + } + connsSender = append(connsSender, conn) + } + return connsSender +} + +func relayReceive(serverConnURL string, peerPairs int) []time.Duration { + connsReceiver := prepareConnsReceiver(serverConnURL, peerPairs) + defer func() { + for i := 0; i < len(connsReceiver); i++ { + if err := connsReceiver[i].Close(); err != nil { + log.Errorf("failed to close connection: %s", err) + } + } + }() + + durations := make(chan time.Duration, len(connsReceiver)) + wg := sync.WaitGroup{} + for _, conn := range connsReceiver { + wg.Add(1) + go func(conn net.Conn) { + defer wg.Done() + duration := runReader(conn) + durations <- duration + }(conn) + } + wg.Wait() + + durationsList := make([]time.Duration, 0, len(connsReceiver)) + for d := range durations { + durationsList = append(durationsList, d) + if len(durationsList) == len(connsReceiver) { + close(durations) + } + } + + return durationsList +} + +func runReader(conn net.Conn) time.Duration { + buf := make([]byte, 8192) + + n, readErr := conn.Read(buf) + if readErr != nil { + log.Errorf("failed to read from channel: %s", readErr) + return 0 + } + + si := DecodeStartIndication(buf[:n]) + log.Infof("received start indication: %v", si) + + receivedSize, err := conn.Read(buf) + if err != nil { + log.Fatalf("failed to read from relay: %s", err) + } + now := time.Now() + + rcv := 0 + for receivedSize < si.TransferSize { + n, readErr = conn.Read(buf) + if readErr != nil { + log.Errorf("failed to read from channel: %s", readErr) + return 0 + } + + receivedSize += n + rcv += n + } + return time.Since(now) +} + +func prepareConnsReceiver(serverConnURL string, peerPairs int) []net.Conn { + clientsReceiver := make([]*client.Client, peerPairs) + for i := 0; i < cap(clientsReceiver); i++ { + c := client.NewClient(context.Background(), serverConnURL, hmacTokenStore, "receiver-"+fmt.Sprint(i)) + err := c.Connect() + if err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + clientsReceiver[i] = c + } + + connsReceiver := make([]net.Conn, 0, peerPairs) + for i := 0; i < len(clientsReceiver); i++ { + conn, err := clientsReceiver[i].OpenConn("sender-" + fmt.Sprint(i)) + if err != nil { + log.Fatalf("failed to bind channel: %s", err) + } + connsReceiver = append(connsReceiver, conn) + } + return connsReceiver +} diff --git a/relay/testec2/signal.go b/relay/testec2/signal.go new file mode 100644 index 000000000..fe93a2fe2 --- /dev/null +++ b/relay/testec2/signal.go @@ -0,0 +1,91 @@ +//go:build linux || darwin + +package main + +import ( + "bytes" + "encoding/json" + "net/http" + + log "github.com/sirupsen/logrus" +) + +type PeerAddr struct { + Address []string +} + +type ClientPeerAddr struct { + Address map[string]string +} + +type Signal struct { + AddressesChan chan []string + ClientAddressChan chan map[string]string +} + +func NewSignalService() *Signal { + return &Signal{ + AddressesChan: make(chan []string), + ClientAddressChan: make(chan map[string]string), + } +} + +func (rs *Signal) Listen(listenAddr string) error { + http.HandleFunc("/", rs.onNewAddresses) + return http.ListenAndServe(listenAddr, nil) +} + +func (rs *Signal) onNewAddresses(w http.ResponseWriter, r *http.Request) { + var msg PeerAddr + err := json.NewDecoder(r.Body).Decode(&msg) + if err != nil { + log.Errorf("Error decoding message: %v", err) + } + + log.Infof("received addresses: %d", len(msg.Address)) + rs.AddressesChan <- msg.Address + clientAddresses := <-rs.ClientAddressChan + + respMsg := ClientPeerAddr{ + Address: clientAddresses, + } + data, err := json.Marshal(respMsg) + if err != nil { + log.Errorf("Error marshalling message: %v", err) + return + } + + _, err = w.Write(data) + if err != nil { + log.Errorf("Error writing response: %v", err) + } +} + +type SignalClient struct { + SignalURL string +} + +func (ss SignalClient) SendAddress(addresses []string) (*ClientPeerAddr, error) { + msg := PeerAddr{ + Address: addresses, + } + data, err := json.Marshal(msg) + if err != nil { + return nil, err + } + + response, err := http.Post(ss.SignalURL, "application/json", bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + + defer response.Body.Close() + + log.Debugf("wait for signal response") + var respPeerAddress ClientPeerAddr + err = json.NewDecoder(response.Body).Decode(&respPeerAddress) + if err != nil { + return nil, err + } + return &respPeerAddress, nil +} diff --git a/relay/testec2/start_msg.go b/relay/testec2/start_msg.go new file mode 100644 index 000000000..19b65380b --- /dev/null +++ b/relay/testec2/start_msg.go @@ -0,0 +1,39 @@ +//go:build linux || darwin + +package main + +import ( + "bytes" + "encoding/gob" + "time" + + log "github.com/sirupsen/logrus" +) + +type StartIndication struct { + Started time.Time + TransferSize int +} + +func NewStartInidication(started time.Time, transferSize int) []byte { + si := StartIndication{ + Started: started, + TransferSize: transferSize, + } + + var data bytes.Buffer + err := gob.NewEncoder(&data).Encode(si) + if err != nil { + log.Fatal("encode error:", err) + } + return data.Bytes() +} + +func DecodeStartIndication(data []byte) StartIndication { + var si StartIndication + err := gob.NewDecoder(bytes.NewReader(data)).Decode(&si) + if err != nil { + log.Fatal("decode error:", err) + } + return si +} diff --git a/relay/testec2/tun/proxy.go b/relay/testec2/tun/proxy.go new file mode 100644 index 000000000..7d84bece7 --- /dev/null +++ b/relay/testec2/tun/proxy.go @@ -0,0 +1,72 @@ +//go:build linux || darwin + +package tun + +import ( + "net" + "sync/atomic" + + log "github.com/sirupsen/logrus" +) + +type Proxy struct { + Device *Device + PConn net.PacketConn + DstAddr net.Addr + shutdownFlag atomic.Bool +} + +func (p *Proxy) Start() { + go p.readFromDevice() + go p.readFromConn() +} + +func (p *Proxy) Close() { + p.shutdownFlag.Store(true) +} + +func (p *Proxy) readFromDevice() { + buf := make([]byte, 1500) + for { + n, err := p.Device.Read(buf) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to read from device: %s", err) + return + } + + _, err = p.PConn.WriteTo(buf[:n], p.DstAddr) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to write to conn: %s", err) + return + } + } +} + +func (p *Proxy) readFromConn() { + buf := make([]byte, 1500) + for { + n, _, err := p.PConn.ReadFrom(buf) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to read from conn: %s", err) + return + } + + _, err = p.Device.Write(buf[:n]) + if err != nil { + if p.shutdownFlag.Load() { + return + } + log.Errorf("failed to write to device: %s", err) + return + } + } +} diff --git a/relay/testec2/tun/tun.go b/relay/testec2/tun/tun.go new file mode 100644 index 000000000..5580785ce --- /dev/null +++ b/relay/testec2/tun/tun.go @@ -0,0 +1,110 @@ +//go:build linux || darwin + +package tun + +import ( + "net" + + log "github.com/sirupsen/logrus" + "github.com/songgao/water" + "github.com/vishvananda/netlink" +) + +type Device struct { + Name string + IP string + PConn net.PacketConn + DstAddr net.Addr + + iFace *water.Interface + proxy *Proxy +} + +func (d *Device) Up() error { + cfg := water.Config{ + DeviceType: water.TUN, + PlatformSpecificParams: water.PlatformSpecificParams{ + Name: d.Name, + }, + } + iFace, err := water.New(cfg) + if err != nil { + return err + } + d.iFace = iFace + + err = d.assignIP() + if err != nil { + return err + } + + err = d.bringUp() + if err != nil { + return err + } + + d.proxy = &Proxy{ + Device: d, + PConn: d.PConn, + DstAddr: d.DstAddr, + } + d.proxy.Start() + return nil +} + +func (d *Device) Close() error { + if d.proxy != nil { + d.proxy.Close() + } + if d.iFace != nil { + return d.iFace.Close() + } + return nil +} + +func (d *Device) Read(b []byte) (int, error) { + return d.iFace.Read(b) +} + +func (d *Device) Write(b []byte) (int, error) { + return d.iFace.Write(b) +} + +func (d *Device) assignIP() error { + iface, err := netlink.LinkByName(d.Name) + if err != nil { + log.Errorf("failed to get TUN device: %v", err) + return err + } + + ip := net.IPNet{ + IP: net.ParseIP(d.IP), + Mask: net.CIDRMask(24, 32), + } + + addr := &netlink.Addr{ + IPNet: &ip, + } + err = netlink.AddrAdd(iface, addr) + if err != nil { + log.Errorf("failed to add IP address: %v", err) + return err + } + return nil +} + +func (d *Device) bringUp() error { + iface, err := netlink.LinkByName(d.Name) + if err != nil { + log.Errorf("failed to get device: %v", err) + return err + } + + // Bring the interface up + err = netlink.LinkSetUp(iface) + if err != nil { + log.Errorf("failed to set device up: %v", err) + return err + } + return nil +} diff --git a/relay/testec2/turn.go b/relay/testec2/turn.go new file mode 100644 index 000000000..8beb40423 --- /dev/null +++ b/relay/testec2/turn.go @@ -0,0 +1,181 @@ +//go:build linux || darwin + +package main + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/netbirdio/netbird/relay/testec2/tun" + + log "github.com/sirupsen/logrus" +) + +type TurnReceiver struct { + conns []*net.UDPConn + clientAddresses map[string]string + devices []*tun.Device +} + +type TurnSender struct { + turnConns map[string]*TurnConn + addresses []string + devices []*tun.Device +} + +func runTurnWriting(tcpConn net.Conn, testData []byte, testDataLen int, wg *sync.WaitGroup) { + defer wg.Done() + defer tcpConn.Close() + + log.Infof("start to sending test data: %s", tcpConn.RemoteAddr()) + + si := NewStartInidication(time.Now(), testDataLen) + _, err := tcpConn.Write(si) + if err != nil { + log.Errorf("failed to write to tcp: %s", err) + return + } + + pieceSize := 1024 + for j := 0; j < testDataLen; j += pieceSize { + end := j + pieceSize + if end > testDataLen { + end = testDataLen + } + _, writeErr := tcpConn.Write(testData[j:end]) + if writeErr != nil { + log.Errorf("failed to write to tcp conn: %s", writeErr) + return + } + } + + // grant time to flush out packages + time.Sleep(3 * time.Second) +} + +func createSenderDevices(sender *TurnSender, clientAddresses *ClientPeerAddr) { + var i int + devices := make([]*tun.Device, 0, len(clientAddresses.Address)) + for k, v := range clientAddresses.Address { + tc, ok := sender.turnConns[k] + if !ok { + log.Fatalf("failed to find turn conn: %s", k) + } + + addr, err := net.ResolveUDPAddr("udp", v) + if err != nil { + log.Fatalf("failed to resolve udp address: %s", err) + } + device := &tun.Device{ + Name: fmt.Sprintf("mtun-sender-%d", i), + IP: fmt.Sprintf("10.0.%d.1", i), + PConn: tc.relayConn, + DstAddr: addr, + } + + err = device.Up() + if err != nil { + log.Fatalf("failed to bring up device: %s", err) + } + + devices = append(devices, device) + i++ + } + sender.devices = devices +} + +func createTurnConns(p int, sender *TurnSender) { + turnConns := make(map[string]*TurnConn) + addresses := make([]string, 0, len(pairs)) + for i := 0; i < p; i++ { + tc := AllocateTurnClient(turnSrvAddress) + log.Infof("allocated turn client: %s", tc.Address().String()) + turnConns[tc.Address().String()] = tc + addresses = append(addresses, tc.Address().String()) + } + + sender.turnConns = turnConns + sender.addresses = addresses +} + +func runTurnReading(d *tun.Device, durations chan time.Duration) { + tcpListener, err := net.Listen("tcp", d.IP+":9999") + if err != nil { + log.Fatalf("failed to listen on tcp: %s", err) + } + log := log.WithField("device", tcpListener.Addr()) + + tcpConn, err := tcpListener.Accept() + if err != nil { + _ = tcpListener.Close() + log.Fatalf("failed to accept connection: %s", err) + } + log.Infof("remote peer connected") + + buf := make([]byte, 103) + n, err := tcpConn.Read(buf) + if err != nil { + _ = tcpListener.Close() + log.Fatalf(errMsgFailedReadTCP, err) + } + + si := DecodeStartIndication(buf[:n]) + log.Infof("received start indication: %v, %d", si, n) + + buf = make([]byte, 8192) + i, err := tcpConn.Read(buf) + if err != nil { + _ = tcpListener.Close() + log.Fatalf(errMsgFailedReadTCP, err) + } + now := time.Now() + for i < si.TransferSize { + n, err := tcpConn.Read(buf) + if err != nil { + _ = tcpListener.Close() + log.Fatalf(errMsgFailedReadTCP, err) + } + i += n + } + durations <- time.Since(now) +} + +func createDevices(addresses []string, receiver *TurnReceiver) error { + receiver.conns = make([]*net.UDPConn, 0, len(addresses)) + receiver.clientAddresses = make(map[string]string, len(addresses)) + receiver.devices = make([]*tun.Device, 0, len(addresses)) + for i, addr := range addresses { + localAddr, err := net.ResolveUDPAddr("udp", udpListener) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %s", err) + } + + conn, err := net.ListenUDP("udp", localAddr) + if err != nil { + return fmt.Errorf("failed to create UDP connection: %s", err) + } + + receiver.conns = append(receiver.conns, conn) + receiver.clientAddresses[addr] = conn.LocalAddr().String() + + dstAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return fmt.Errorf("failed to resolve address: %s", err) + } + + device := &tun.Device{ + Name: fmt.Sprintf("mtun-%d", i), + IP: fmt.Sprintf("10.0.%d.2", i), + PConn: conn, + DstAddr: dstAddr, + } + + if err = device.Up(); err != nil { + return fmt.Errorf("failed to bring up device: %s, %s", device.Name, err) + } + receiver.devices = append(receiver.devices, device) + } + return nil +} diff --git a/relay/testec2/turn_allocator.go b/relay/testec2/turn_allocator.go new file mode 100644 index 000000000..fd86208df --- /dev/null +++ b/relay/testec2/turn_allocator.go @@ -0,0 +1,83 @@ +//go:build linux || darwin + +package main + +import ( + "fmt" + "net" + + "github.com/pion/logging" + "github.com/pion/turn/v3" + log "github.com/sirupsen/logrus" +) + +type TurnConn struct { + conn net.Conn + turnClient *turn.Client + relayConn net.PacketConn +} + +func (tc *TurnConn) Address() net.Addr { + return tc.relayConn.LocalAddr() +} + +func (tc *TurnConn) Close() { + _ = tc.relayConn.Close() + tc.turnClient.Close() + _ = tc.conn.Close() +} + +func AllocateTurnClient(serverAddr string) *TurnConn { + conn, err := net.Dial("tcp", serverAddr) + if err != nil { + log.Fatal(err) + } + + turnClient, err := getTurnClient(serverAddr, conn) + if err != nil { + log.Fatal(err) + } + + relayConn, err := turnClient.Allocate() + if err != nil { + log.Fatal(err) + } + + return &TurnConn{ + conn: conn, + turnClient: turnClient, + relayConn: relayConn, + } +} + +func getTurnClient(address string, conn net.Conn) (*turn.Client, error) { + // Dial TURN Server + addrStr := fmt.Sprintf("%s:%d", address, 443) + + fac := logging.NewDefaultLoggerFactory() + //fac.DefaultLogLevel = logging.LogLevelTrace + + // Start a new TURN Client and wrap our net.Conn in a STUNConn + // This allows us to simulate datagram based communication over a net.Conn + cfg := &turn.ClientConfig{ + TURNServerAddr: address, + Conn: turn.NewSTUNConn(conn), + Username: "test", + Password: "test", + LoggerFactory: fac, + } + + client, err := turn.NewClient(cfg) + if err != nil { + return nil, fmt.Errorf("failed to create TURN client for server %s: %s", addrStr, err) + } + + // Start listening on the conn provided. + err = client.Listen() + if err != nil { + client.Close() + return nil, fmt.Errorf("failed to listen on TURN client for server %s: %s", addrStr, err) + } + + return client, nil +} diff --git a/release_files/install.sh b/release_files/install.sh index 198d74428..b7a6c08f9 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -21,6 +21,8 @@ SUDO="" if command -v sudo > /dev/null && [ "$(id -u)" -ne 0 ]; then SUDO="sudo" +elif command -v doas > /dev/null && [ "$(id -u)" -ne 0 ]; then + SUDO="doas" fi if [ -z ${NETBIRD_RELEASE+x} ]; then @@ -31,14 +33,16 @@ get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then local TAG="latest" + local URL="https://pkgs.netbird.io/releases/latest" else local TAG="tags/${RELEASE}" + local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' else - curl -s "https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" \ + curl -s "${URL}" \ | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' fi } @@ -68,7 +72,7 @@ download_release_binary() { if [ -n "$GITHUB_TOKEN" ]; then cd /tmp && curl -H "Authorization: token ${GITHUB_TOKEN}" -LO "$DOWNLOAD_URL" else - cd /tmp && curl -LO "$DOWNLOAD_URL" + cd /tmp && curl -LO "$DOWNLOAD_URL" || curl -LO --dns-servers 8.8.8.8 "$DOWNLOAD_URL" fi @@ -151,6 +155,22 @@ add_aur_repo() { ${SUDO} pacman -Rs "$REMOVE_PKGS" --noconfirm } +prepare_tun_module() { + # Create the necessary file structure for /dev/net/tun + if [ ! -c /dev/net/tun ]; then + if [ ! -d /dev/net ]; then + mkdir -m 755 /dev/net + fi + mknod /dev/net/tun c 10 200 + chmod 0755 /dev/net/tun + fi + + # Load the tun module if not already loaded + if ! lsmod | grep -q "^tun\s"; then + insmod /lib/modules/tun.ko + fi +} + install_native_binaries() { # Checks for supported architecture case "$ARCH" in @@ -226,6 +246,13 @@ install_netbird() { ${SUDO} dnf -y install netbird-ui fi ;; + rpm-ostree) + add_rpm_repo + ${SUDO} rpm-ostree -y install netbird + if ! $SKIP_UI_APP; then + ${SUDO} rpm-ostree -y install netbird-ui + fi + ;; pacman) ${SUDO} pacman -Syy add_aur_repo @@ -268,16 +295,22 @@ install_netbird() { ;; esac + if [ "$OS_NAME" = "synology" ]; then + prepare_tun_module + fi + # Add package manager to config ${SUDO} mkdir -p "$CONFIG_FOLDER" echo "package_manager=$PACKAGE_MANAGER" | ${SUDO} tee "$CONFIG_FILE" > /dev/null # Load and start netbird service - if ! ${SUDO} netbird service install 2>&1; then - echo "NetBird service has already been loaded" - fi - if ! ${SUDO} netbird service start 2>&1; then - echo "NetBird service has already been started" + if [ "$PACKAGE_MANAGER" != "rpm-ostree" ]; then + if ! ${SUDO} netbird service install 2>&1; then + echo "NetBird service has already been loaded" + fi + if ! ${SUDO} netbird service start 2>&1; then + echo "NetBird service has already been started" + fi fi @@ -287,7 +320,7 @@ install_netbird() { } version_greater_equal() { - printf '%s\n%s\n' "$2" "$1" | sort -V -C + printf '%s\n%s\n' "$2" "$1" | sort -V -c } is_bin_package_manager() { @@ -383,6 +416,9 @@ if type uname >/dev/null 2>&1; then elif [ -x "$(command -v dnf)" ]; then PACKAGE_MANAGER="dnf" echo "The installation will be performed using dnf package manager" + elif [ -x "$(command -v rpm-ostree)" ]; then + PACKAGE_MANAGER="rpm-ostree" + echo "The installation will be performed using rpm-ostree package manager" elif [ -x "$(command -v yum)" ]; then PACKAGE_MANAGER="yum" echo "The installation will be performed using yum package manager" diff --git a/route/route.go b/route/route.go index eb6c36bd8..e23801e6e 100644 --- a/route/route.go +++ b/route/route.go @@ -100,6 +100,7 @@ type Route struct { Metric int Enabled bool Groups []string `gorm:"serializer:json"` + AccessControlGroups []string `gorm:"serializer:json"` } // EventMeta returns activity event meta related to the route @@ -123,6 +124,7 @@ func (r *Route) Copy() *Route { Masquerade: r.Masquerade, Enabled: r.Enabled, Groups: slices.Clone(r.Groups), + AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route } @@ -147,7 +149,8 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups) + slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } // IsDynamic returns if the route is dynamic, i.e. has domains diff --git a/signal/client/client.go b/signal/client/client.go index 9d99b3677..ced3fb7d0 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -51,11 +51,10 @@ func UnMarshalCredential(msg *proto.Message) (*Credential, error) { } // MarshalCredential marshal a Credential instance and returns a Message object -func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, credential *Credential, t proto.Body_Type, - rosenpassPubKey []byte, rosenpassAddr string) (*proto.Message, error) { +func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey string, credential *Credential, t proto.Body_Type, rosenpassPubKey []byte, rosenpassAddr string, relaySrvAddress string) (*proto.Message, error) { return &proto.Message{ Key: myKey.PublicKey().String(), - RemoteKey: remoteKey.String(), + RemoteKey: remoteKey, Body: &proto.Body{ Type: t, Payload: fmt.Sprintf("%s:%s", credential.UFrag, credential.Pwd), @@ -65,6 +64,7 @@ func MarshalCredential(myKey wgtypes.Key, myPort int, remoteKey wgtypes.Key, cre RosenpassPubKey: rosenpassPubKey, RosenpassServerAddr: rosenpassAddr, }, + RelayServerAddress: relaySrvAddress, }, }, nil } diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 2525493b4..f7d4ebc50 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -199,7 +199,7 @@ func startSignal() (*grpc.Server, net.Listener) { panic(err) } s := grpc.NewServer() - srv, err := server.NewServer(otel.Meter("")) + srv, err := server.NewServer(context.Background(), otel.Meter("")) if err != nil { panic(err) } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 61f7a32a7..1bb2f1d0c 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -29,12 +29,9 @@ import ( "google.golang.org/grpc/keepalive" ) -const ( - metricsPort = 9090 -) - var ( signalPort int + metricsPort int signalLetsencryptDomain string signalSSLDir string defaultSignalSSLDir string @@ -105,7 +102,7 @@ var ( } }() - srv, err := server.NewServer(metricsServer.Meter) + srv, err := server.NewServer(cmd.Context(), metricsServer.Meter) if err != nil { return fmt.Errorf("creating signal server: %v", err) } @@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { func init() { runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, "Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") + runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") diff --git a/signal/peer/peer.go b/signal/peer/peer.go index 3149526b2..ed2360d67 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -18,16 +18,20 @@ type Peer struct { StreamID int64 - //a gRpc connection stream to the Peer + // a gRpc connection stream to the Peer Stream proto.SignalExchange_ConnectStreamServer + + // registration time + RegisteredAt time.Time } // NewPeer creates a new instance of a connected Peer func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer) *Peer { return &Peer{ - Id: id, - Stream: stream, - StreamID: time.Now().UnixNano(), + Id: id, + Stream: stream, + StreamID: time.Now().UnixNano(), + RegisteredAt: time.Now(), } } @@ -78,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) { log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. Will override stream.", peer.Id, peer.StreamID, pp.StreamID) registry.Peers.Store(peer.Id, peer) + return } + log.Debugf("peer registered [%s]", peer.Id) + registry.metrics.ActivePeers.Add(context.Background(), 1) // record time as milliseconds registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) @@ -101,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) { peer.Id, pp.StreamID, peer.StreamID) return } + registry.metrics.ActivePeers.Add(context.Background(), -1) + log.Debugf("peer deregistered [%s]", peer.Id) + registry.metrics.Deregistrations.Add(context.Background(), 1) } - log.Debugf("peer deregistered [%s]", peer.Id) - - registry.metrics.Deregistrations.Add(context.Background(), 1) } diff --git a/signal/proto/signalexchange.pb.go b/signal/proto/signalexchange.pb.go index 782c45da1..30f704c6f 100644 --- a/signal/proto/signalexchange.pb.go +++ b/signal/proto/signalexchange.pb.go @@ -1,15 +1,15 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v3.12.4 +// protoc v3.21.12 // source: signalexchange.proto package proto import ( - _ "github.com/golang/protobuf/protoc-gen-go/descriptor" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" + _ "google.golang.org/protobuf/types/descriptorpb" reflect "reflect" sync "sync" ) @@ -225,6 +225,8 @@ type Body struct { FeaturesSupported []uint32 `protobuf:"varint,6,rep,packed,name=featuresSupported,proto3" json:"featuresSupported,omitempty"` // RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to RosenpassConfig *RosenpassConfig `protobuf:"bytes,7,opt,name=rosenpassConfig,proto3" json:"rosenpassConfig,omitempty"` + // relayServerAddress is an IP:port of the relay server + RelayServerAddress string `protobuf:"bytes,8,opt,name=relayServerAddress,proto3" json:"relayServerAddress,omitempty"` } func (x *Body) Reset() { @@ -308,6 +310,13 @@ func (x *Body) GetRosenpassConfig() *RosenpassConfig { return nil } +func (x *Body) GetRelayServerAddress() string { + if x != nil { + return x.RelayServerAddress + } + return "" +} + // Mode indicates a connection mode type Mode struct { state protoimpl.MessageState @@ -431,7 +440,7 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x52, 0x09, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x04, 0x62, 0x6f, 0x64, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x52, - 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xf6, 0x02, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, + 0x04, 0x62, 0x6f, 0x64, 0x79, 0x22, 0xa6, 0x03, 0x0a, 0x04, 0x42, 0x6f, 0x64, 0x79, 0x12, 0x2d, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x42, 0x6f, 0x64, 0x79, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x18, 0x0a, @@ -451,7 +460,10 @@ var file_signalexchange_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x65, 0x78, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x2e, 0x52, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x08, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x12, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0x36, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x09, 0x0a, 0x05, 0x4f, 0x46, 0x46, 0x45, 0x52, 0x10, 0x00, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x4e, 0x53, 0x57, 0x45, 0x52, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x41, 0x4e, 0x44, 0x49, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x08, 0x0a, 0x04, 0x4d, 0x4f, 0x44, 0x45, 0x10, 0x04, 0x22, 0x2e, diff --git a/signal/proto/signalexchange.proto b/signal/proto/signalexchange.proto index a8c4c309c..4431edd7c 100644 --- a/signal/proto/signalexchange.proto +++ b/signal/proto/signalexchange.proto @@ -60,6 +60,9 @@ message Body { // RosenpassConfig is a Rosenpass config of the remote peer our peer tries to connect to RosenpassConfig rosenpassConfig = 7; + + // relayServerAddress is url of the relay server + string relayServerAddress = 8; } // Mode indicates a connection mode diff --git a/signal/server/signal.go b/signal/server/signal.go index 219bdcc41..63cc43bd7 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -13,6 +13,8 @@ import ( "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "github.com/netbirdio/signal-dispatcher/dispatcher" + "github.com/netbirdio/netbird/signal/metrics" "github.com/netbirdio/netbird/signal/peer" "github.com/netbirdio/netbird/signal/proto" @@ -40,20 +42,26 @@ const ( type Server struct { registry *peer.Registry proto.UnimplementedSignalExchangeServer - - metrics *metrics.AppMetrics + dispatcher *dispatcher.Dispatcher + metrics *metrics.AppMetrics } // NewServer creates a new Signal server -func NewServer(meter metric.Meter) (*Server, error) { +func NewServer(ctx context.Context, meter metric.Meter) (*Server, error) { appMetrics, err := metrics.NewAppMetrics(meter) if err != nil { return nil, fmt.Errorf("creating app metrics: %v", err) } + dispatcher, err := dispatcher.NewDispatcher(ctx, meter) + if err != nil { + return nil, fmt.Errorf("creating dispatcher: %v", err) + } + s := &Server{ - registry: peer.NewRegistry(appMetrics), - metrics: appMetrics, + dispatcher: dispatcher, + registry: peer.NewRegistry(appMetrics), + metrics: appMetrics, } return s, nil @@ -61,57 +69,26 @@ func NewServer(meter metric.Meter) (*Server, error) { // Send forwards a message to the signal peer func (s *Server) Send(ctx context.Context, msg *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { - if !s.registry.IsPeerRegistered(msg.Key) { - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotRegistered))) + log.Debugf("received a new message to send from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - return nil, fmt.Errorf("peer %s is not registered", msg.Key) + if _, found := s.registry.Get(msg.RemoteKey); found { + s.forwardMessageToPeer(ctx, msg) + return &proto.EncryptedMessage{}, nil } - getRegistrationStart := time.Now() - - if dstPeer, found := s.registry.Get(msg.RemoteKey); found { - s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationFound))) - start := time.Now() - //forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) - //todo respond to the sender? - - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - } else { - s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage))) - s.metrics.MessagesForwarded.Add(context.Background(), 1) - } - } else { - s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeMessage), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) - log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) - //todo respond to the sender? - - s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) - } - return &proto.EncryptedMessage{}, nil + return s.dispatcher.SendMessage(context.Background(), msg) } // ConnectStream connects to the exchange stream func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) error { - p, err := s.connectPeer(stream) + p, err := s.RegisterPeer(stream) if err != nil { return err } - startRegister := time.Now() + defer s.DeregisterPeer(p) - s.metrics.ActivePeers.Add(stream.Context(), 1) - - defer func() { - log.Infof("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) - s.registry.Deregister(p) - - s.metrics.PeerConnectionDuration.Record(stream.Context(), int64(time.Since(startRegister).Seconds())) - s.metrics.ActivePeers.Add(context.Background(), -1) - }() - - //needed to confirm that the peer has been registered so that the client can proceed + // needed to confirm that the peer has been registered so that the client can proceed header := metadata.Pairs(proto.HeaderRegistered, "1") err = stream.SendHeader(header) if err != nil { @@ -119,11 +96,10 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) return err } - log.Infof("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) + log.Debugf("peer connected [%s] [streamID %d] ", p.Id, p.StreamID) for { - - //read incoming messages + // read incoming messages msg, err := stream.Recv() if err == io.EOF { break @@ -131,44 +107,26 @@ func (s *Server) ConnectStream(stream proto.SignalExchange_ConnectStreamServer) return err } - log.Debugf("received a new message from peer [%s] to peer [%s]", p.Id, msg.RemoteKey) + log.Debugf("Received a response from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) - getRegistrationStart := time.Now() - - // lookup the target peer where the message is going to - if dstPeer, found := s.registry.Get(msg.RemoteKey); found { - s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) - start := time.Now() - //forward the message to the target peer - if err := dstPeer.Stream.Send(msg); err != nil { - log.Errorf("error while forwarding message from peer [%s] to peer [%s] %v", p.Id, msg.RemoteKey, err) - //todo respond to the sender? - s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) - } else { - // in milliseconds - s.metrics.MessageForwardLatency.Record(stream.Context(), float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) - s.metrics.MessagesForwarded.Add(stream.Context(), 1) - } - } else { - s.metrics.GetRegistrationDelay.Record(stream.Context(), float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) - s.metrics.MessageForwardFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) - log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", p.Id, msg.RemoteKey) - //todo respond to the sender? + _, err = s.dispatcher.SendMessage(stream.Context(), msg) + if err != nil { + log.Debugf("error while sending message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) } } + <-stream.Context().Done() return stream.Context().Err() } -// Handles initial Peer connection. -// Each connection must provide an Id header. -// At this moment the connecting Peer will be registered in the peer.Registry -func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { +func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) (*peer.Peer, error) { + log.Debugf("registering new peer") if meta, hasMeta := metadata.FromIncomingContext(stream.Context()); hasMeta { if id, found := meta[proto.HeaderId]; found { p := peer.NewPeer(id[0], stream) s.registry.Register(p) + s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) return p, nil } else { @@ -180,3 +138,37 @@ func (s Server) connectPeer(stream proto.SignalExchange_ConnectStreamServer) (*p return nil, status.Errorf(codes.FailedPrecondition, "missing connection stream meta") } } + +func (s *Server) DeregisterPeer(p *peer.Peer) { + log.Debugf("peer disconnected [%s] [streamID %d] ", p.Id, p.StreamID) + s.registry.Deregister(p) + + s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) +} + +func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { + log.Debugf("forwarding a new message from peer [%s] to peer [%s]", msg.Key, msg.RemoteKey) + + getRegistrationStart := time.Now() + + // lookup the target peer where the message is going to + if dstPeer, found := s.registry.Get(msg.RemoteKey); found { + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationFound))) + start := time.Now() + // forward the message to the target peer + if err := dstPeer.Stream.Send(msg); err != nil { + log.Warnf("error while forwarding message from peer [%s] to peer [%s] %v", msg.Key, msg.RemoteKey, err) + // todo respond to the sender? + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeError))) + } else { + // in milliseconds + s.metrics.MessageForwardLatency.Record(ctx, float64(time.Since(start).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream))) + s.metrics.MessagesForwarded.Add(ctx, 1) + } + } else { + s.metrics.GetRegistrationDelay.Record(ctx, float64(time.Since(getRegistrationStart).Nanoseconds())/1e6, metric.WithAttributes(attribute.String(labelType, labelTypeStream), attribute.String(labelRegistrationStatus, labelRegistrationNotFound))) + s.metrics.MessageForwardFailures.Add(ctx, 1, metric.WithAttributes(attribute.String(labelType, labelTypeNotConnected))) + log.Debugf("message from peer [%s] can't be forwarded to peer [%s] because destination peer is not connected", msg.Key, msg.RemoteKey) + // todo respond to the sender? + } +} diff --git a/util/file.go b/util/file.go index 2a6182556..8355488c9 100644 --- a/util/file.go +++ b/util/file.go @@ -10,51 +10,30 @@ import ( log "github.com/sirupsen/logrus" ) -// WriteJson writes JSON config object to a file creating parent directories if required -// The output JSON is pretty-formatted -func WriteJson(file string, obj interface{}) error { - +// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory +func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - // make it pretty - bs, err := json.MarshalIndent(obj, "", " ") + err = EnforcePermission(file) if err != nil { return err } - tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) + return writeJson(file, obj, configDir, configFileName) +} + +// WriteJson writes JSON config object to a file creating parent directories if required +// The output JSON is pretty-formatted +func WriteJson(file string, obj interface{}) error { + configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() - if err != nil { - return err - } - - defer func() { - _, err = os.Stat(tempFileName) - if err == nil { - os.Remove(tempFileName) - } - }() - - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - - err = os.Rename(tempFileName, file) - if err != nil { - return err - } - - return nil + return writeJson(file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } +func writeJson(file string, obj interface{}, configDir string, configFileName string) error { + + // make it pretty + bs, err := json.MarshalIndent(obj, "", " ") + if err != nil { + return err + } + + tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) + if err != nil { + return err + } + + tempFileName := tempFile.Name() + // closing file ops as windows doesn't allow to move it + err = tempFile.Close() + if err != nil { + return err + } + + defer func() { + _, err = os.Stat(tempFileName) + if err == nil { + os.Remove(tempFileName) + } + }() + + err = os.WriteFile(tempFileName, bs, 0600) + if err != nil { + return err + } + + err = os.Rename(tempFileName, file) + if err != nil { + return err + } + + return nil +} + func openOrCreateFile(file string) (*os.File, error) { s, err := os.Stat(file) if err == nil { @@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) { } err := os.MkdirAll(configDir, 0750) + if err != nil { + return "", "", err + } + return configDir, configFileName, err } diff --git a/util/log.go b/util/log.go index 4bce75e4a..7a9235ee6 100644 --- a/util/log.go +++ b/util/log.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "slices" + "strconv" log "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" @@ -12,6 +13,8 @@ import ( "github.com/netbirdio/netbird/formatter" ) +const defaultLogSize = 5 + // InitLog parses and sets log-level input func InitLog(logLevel string, logPath string) error { level, err := log.ParseLevel(logLevel) @@ -19,13 +22,14 @@ func InitLog(logLevel string, logPath string) error { log.Errorf("Failed parsing log-level %s: %s", logLevel, err) return err } - customOutputs := []string{"console", "syslog"}; + customOutputs := []string{"console", "syslog"} if logPath != "" && !slices.Contains(customOutputs, logPath) { + maxLogSize := getLogMaxSize() lumberjackLogger := &lumberjack.Logger{ // Log file absolute path, os agnostic Filename: filepath.ToSlash(logPath), - MaxSize: 5, // MB + MaxSize: maxLogSize, // MB MaxBackups: 10, MaxAge: 30, // days Compress: true, @@ -46,3 +50,18 @@ func InitLog(logLevel string, logPath string) error { log.SetLevel(level) return nil } + +func getLogMaxSize() int { + if sizeVar, ok := os.LookupEnv("NB_LOG_MAX_SIZE_MB"); ok { + size, err := strconv.ParseInt(sizeVar, 10, 64) + if err != nil { + log.Errorf("Failed parsing log-size %s: %s. Should be just an integer", sizeVar, err) + return defaultLogSize + } + + log.Infof("Setting log file max size to %d MB", size) + + return int(size) + } + return defaultLogSize +} diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go index 7a5de7587..4032a75c0 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_nonios.go @@ -49,6 +49,8 @@ func RemoveDialerHooks() { // DialContext wraps the net.Dialer's DialContext method to use the custom connection func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + log.Debugf("Dialing %s %s", network, address) + if CustomRoutingDisabled() { return d.Dialer.DialContext(ctx, network, address) } diff --git a/util/net/net.go b/util/net/net.go index 8d1fcebd0..61b47dbe7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -4,7 +4,7 @@ import ( "net" "os" - "github.com/netbirdio/netbird/iface/netstack" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) diff --git a/util/permission.go b/util/permission.go new file mode 100644 index 000000000..666998cff --- /dev/null +++ b/util/permission.go @@ -0,0 +1,7 @@ +//go:build !windows + +package util + +func EnforcePermission(dirPath string) error { + return nil +} diff --git a/util/permission_windows.go b/util/permission_windows.go new file mode 100644 index 000000000..548fef824 --- /dev/null +++ b/util/permission_windows.go @@ -0,0 +1,86 @@ +package util + +import ( + "path/filepath" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + securityFlags = windows.OWNER_SECURITY_INFORMATION | + windows.GROUP_SECURITY_INFORMATION | + windows.DACL_SECURITY_INFORMATION | + windows.PROTECTED_DACL_SECURITY_INFORMATION +) + +func EnforcePermission(file string) error { + dirPath := filepath.Dir(file) + + user, group, err := sids() + if err != nil { + return err + } + + adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return err + } + + explicitAccess := []windows.EXPLICIT_ACCESS{ + { + AccessPermissions: windows.GENERIC_ALL, + AccessMode: windows.SET_ACCESS, + Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + Trustee: windows.TRUSTEE{ + MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE, + TrusteeForm: windows.TRUSTEE_IS_SID, + TrusteeType: windows.TRUSTEE_IS_USER, + TrusteeValue: windows.TrusteeValueFromSID(user), + }, + }, + { + AccessPermissions: windows.GENERIC_ALL, + AccessMode: windows.SET_ACCESS, + Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + Trustee: windows.TRUSTEE{ + MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE, + TrusteeForm: windows.TRUSTEE_IS_SID, + TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP, + TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid), + }, + }, + } + + dacl, err := windows.ACLFromEntries(explicitAccess, nil) + if err != nil { + return err + } + + return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil) +} + +func sids() (*windows.SID, *windows.SID, error) { + var token windows.Token + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token) + if err != nil { + return nil, nil, err + } + defer func() { + if err := token.Close(); err != nil { + log.Errorf("failed to close process token: %v", err) + } + }() + + tu, err := token.GetTokenUser() + if err != nil { + return nil, nil, err + } + + pg, err := token.GetTokenPrimaryGroup() + if err != nil { + return nil, nil, err + } + + return tu.User.Sid, pg.PrimaryGroup, nil +}