Compare commits

...

35 Commits

Author SHA1 Message Date
Viktor Liu
1010629ea2 Merge pull request #2 from lixmal/windows-msi-cmd-fix
Remove auto sections
2025-06-24 14:28:17 +02:00
Viktor Liu
955d48abb9 Remove auto sections 2025-06-24 14:27:30 +02:00
Viktor Liu
d6dad20c83 Merge pull request #1 from lixmal/windows-msi-cmd-fix
Windows msi cmd fix
2025-06-24 14:19:56 +02:00
Viktor Liu
04676c5368 Remove auto sections 2025-06-24 14:17:15 +02:00
Viktor Liu
b53a517bf8 Don't open cmd.exe during MSI actions 2025-06-24 13:01:20 +02:00
Viktor Liu
f37aa2cc9d [misc] Specify netbird binary location in Dockerfiles (#4024) 2025-06-23 10:09:02 +02:00
Maycon Santos
5343bee7b2 [management] check and log on new management version (#4029)
This PR enhances the version checker to send a custom User-Agent header when polling for updates, and configures both the management CLI and client UI to use distinct agents. 

- NewUpdate now takes an `httpAgent` string to set the User-Agent header.
- `fetchVersion` builds a custom HTTP request (instead of `http.Get`) and sets the User-Agent.
- Management CLI and client UI now pass `"nb/management"` and `"nb/client-ui"` respectively to NewUpdate.
- Tests updated to supply an `httpAgent` constant.
- Logs if there is a new version available for management
2025-06-22 16:44:33 +02:00
Maycon Santos
870e29db63 [misc] add additional metrics (#4028)
* add additional metrics

we are collecting active rosenpass, ssh from the client side
we are also collecting active user peers and active users

* remove duplicated
2025-06-22 13:44:25 +02:00
Maycon Santos
08e9b05d51 [client] close windows when process needs to exit (#4027)
This PR fixes a bug by ensuring that the advanced settings and re-authentication windows are closed appropriately when the main GUI process exits.

- Updated runSelfCommand calls throughout the UI to pass a context parameter.
- Modified runSelfCommand’s signature and its internal command invocation to use exec.CommandContext for proper cancellation handling.
2025-06-22 10:33:04 +02:00
hakansa
3581648071 [client] Refactor showLoginURL to improve error handling and connection status checks (#4026)
This PR refactors showLoginURL to improve error handling and connection status checks by delaying the login fetch until user interaction and closing the pop-up if already connected.

- Moved s.login(false) call into the click handler to defer network I/O.
- Added a conn.Status check after opening the URL to skip reconnection if already connected.
- Enhanced error logs for missing verification URLs and service status failures.
2025-06-22 10:03:58 +02:00
Viktor Liu
2a51609436 [client] Handle lazy routing peers that are part of HA groups (#3943)
* Activate new lazy routing peers if the HA group is active
* Prevent lazy peers going to idle if HA group members are active (#3948)
2025-06-20 18:07:19 +02:00
Pascal Fischer
83457f8b99 [management] add transaction for integrated validator groups update and primary account update (#4014) 2025-06-20 12:13:24 +02:00
Pascal Fischer
b45284f086 [management] export ephemeral peer flag on api (#4004) 2025-06-19 16:46:56 +02:00
Bethuel Mmbaga
e9016aecea [management] Add backward compatibility for older clients without firewall rules port range support (#4003)
Adds backward compatibility for clients with versions prior to v0.48.0 that do not support port range firewall rules.

- Skips generation of firewall rules with multi-port ranges for older clients
- Preserves support for single-port ranges by treating them as individual port rules, ensuring compatibility with older clients
2025-06-19 13:07:06 +03:00
Viktor Liu
23b5d45b68 [client] Fix port range squashing (#4007) 2025-06-18 18:56:48 +02:00
Viktor Liu
0e5dc9d412 [client] Add more Android advanced settings (#4001) 2025-06-18 17:23:23 +02:00
Zoltan Papp
91f7ee6a3c Fix route notification
On Android ignore the dynamic roots in the route notifications
2025-06-18 16:49:03 +02:00
Bethuel Mmbaga
7c6b85b4cb [management] Refactor routes to use store methods (#2928) 2025-06-18 16:40:29 +03:00
hakansa
08c9107c61 [client] fix connection state handling (#3995)
[client] fix connection state handling (#3995)
2025-06-17 17:14:08 +03:00
hakansa
81d83245e1 [client] Fix logic in updateStatus to correctly handle connection state (#3994)
[client] Fix logic in updateStatus to correctly handle connection state (#3994)
2025-06-17 17:02:04 +03:00
Maycon Santos
af2b427751 [management] Avoid recalculating next peer expiration (#3991)
* Avoid recalculating next peer expiration

- Check if an account schedule is already running
- Cancel executing schedules only when changes occurs
- Add more context info to logs

* fix tests
2025-06-17 15:14:11 +02:00
hakansa
f61ebdb3bc [client] Fix DNS Interceptor Build Error (#3993)
[client] Fix DNS Interceptor Build Error
2025-06-17 16:07:14 +03:00
Viktor Liu
de7384e8ea [client] Tighten allowed domains for dns forwarder (#3978) 2025-06-17 14:03:00 +02:00
Viktor Liu
75c1be69cf [client] Prioritze the local resolver in the dns handler chain (#3965) 2025-06-17 14:02:30 +02:00
hakansa
424ae28de9 [client] Fix UI Download URL (#3990)
[client] Fix UI Download URL
2025-06-17 11:55:48 +03:00
Viktor Liu
d4a800edd5 [client] Fix status recorder panic (#3988) 2025-06-17 01:20:26 +02:00
Maycon Santos
dd9917f1a8 [misc] add missing images (#3987) 2025-06-16 21:05:49 +02:00
Viktor Liu
8df8c1012f [client] Support wildcard DNS on iOS (#3979) 2025-06-16 18:33:51 +02:00
Viktor Liu
bfa5c21d2d [client] Improve icmp conntrack log (#3963) 2025-06-16 10:12:59 +02:00
Maycon Santos
b1247a14ba [management] Use xID for setup key IDs to avoid id collisions (#3977)
This PR addresses potential ID collisions by switching the setup key ID generation from a hash-based approach to using xid-generated IDs.

Replace the hash function with xid.New().String()
Remove obsolete imports and the Hash() function
2025-06-14 12:24:16 +01:00
Philippe Vaucher
f595057a0b [signal] Set flags from environment variables (#3972) 2025-06-14 00:08:34 +02:00
hakansa
089d442fb2 [client] Display login popup on session expiration (#3955)
This PR implements a feature enhancement to display a login popup when the session expires. Key changes include updating flag handling and client construction to support a new login URL popup, revising login and notification handling logic to use the new popup, and updating status and server-side session state management accordingly.
2025-06-13 23:51:57 +02:00
Viktor Liu
04a3765391 [client] Fix unncessary UI updates (#3785) 2025-06-13 20:38:50 +02:00
Zoltan Papp
d24d8328f9 [client] Propagation networks for Android client (#3966)
Add networks propagation
2025-06-13 11:04:17 +02:00
Vlad
4f63996ae8 [management] added events streaming metrics (#3814) 2025-06-12 18:48:54 +01:00
96 changed files with 3494 additions and 2589 deletions

View File

@@ -1,21 +0,0 @@
name: Git Town
on:
pull_request:
branches:
- '**'
jobs:
git-town:
name: Display the branch stack
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
with:
skip-single-stacks: true

View File

@@ -1,46 +0,0 @@
name: "Darwin"
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
macos-gotest-
macos-go-
- name: Install libpcap
run: brew install libpcap
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -1,52 +0,0 @@
name: "FreeBSD"
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -1,568 +0,0 @@
name: Linux
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
name: "Client / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
${{ steps.go-env.outputs.cache_dir }}
${{ steps.go-env.outputs.modcache_dir }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Run tests in container
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
docker run --rm \
--cap-add=NET_ADMIN \
--privileged \
-v $PWD:/app \
-w /app \
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
-e CGO_ENABLED=1 \
-e CI=true \
-e DOCKER_CI=true \
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management:
name: "Management / Unit"
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/...
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...

View File

@@ -1,72 +0,0 @@
name: "Windows"
on:
push:
branches:
- main
pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
id: go
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -1,58 +0,0 @@
name: Lint
on: [pull_request]
permissions:
contents: read
pull-requests: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
codespell:
name: codespell
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
golangci:
strategy:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
with:
version: latest
args: --timeout=12m --out-format colored-line-number

View File

@@ -1,37 +0,0 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -1,67 +0,0 @@
name: Mobile
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
android_build:
name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -55,23 +55,23 @@ jobs:
run: go mod tidy run: go mod tidy
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU # - name: Set up QEMU
uses: docker/setup-qemu-action@v2 # uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx # - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 # uses: docker/setup-buildx-action@v2
- name: Login to Docker hub # - name: Login to Docker hub
if: github.event_name != 'pull_request' # if: github.event_name != 'pull_request'
uses: docker/login-action@v1 # uses: docker/login-action@v1
with: # with:
username: ${{ secrets.DOCKER_USER }} # username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} # password: ${{ secrets.DOCKER_TOKEN }}
- name: Log in to the GitHub container registry # - name: Log in to the GitHub container registry
if: github.event_name != 'pull_request' # if: github.event_name != 'pull_request'
uses: docker/login-action@v3 # uses: docker/login-action@v3
with: # with:
registry: ghcr.io # registry: ghcr.io
username: ${{ github.actor }} # username: ${{ github.actor }}
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }} # password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu

View File

@@ -1,22 +0,0 @@
name: sync main
on:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_main:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'

View File

@@ -1,23 +0,0 @@
name: sync tag
on:
push:
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -1,310 +0,0 @@
name: Test Infrastructure files
on:
push:
branches:
- main
pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-docker-compose:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
mysql:
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
env:
MYSQL_USER: netbird
MYSQL_PASSWORD: mysql
MYSQL_ROOT_PASSWORD: mysqlroot
MYSQL_DATABASE: netbird
options: >-
--health-cmd "mysqladmin ping --silent"
--health-interval 10s
--health-timeout 5s
ports:
- 3306:3306
steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
if [ "${{ matrix.store }}" == "mysql" ]; then
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
fi
- name: Install jq
run: sudo apt-get install -y jq
- name: Install curl
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
- name: run configure
working-directory: infrastructure_files
run: bash -x configure.sh
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
CI_NETBIRD_USE_AUTH0: true
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
working-directory: infrastructure_files/artifacts
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
CI_NETBIRD_USE_AUTH0: true
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_AUTH_AUTHORITY: https://example.eu.auth0.com/
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
run: |
set -x
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
run: CGO_ENABLED=0 go build -o netbird-signal main.go
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
docker compose up -d
sleep 5
docker compose ps
docker compose logs --tail=20
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
test-getting-started-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres
run: test -f Caddyfile
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml
- name: test management.json file gen postgres
run: test -f management.json
- name: test turnserver.conf file gen postgres
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen postgres
run: test -f zitadel.env
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test relay.env file gen postgres
run: test -f relay.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- name: test relay.env file gen CockroachDB
run: test -f relay.env

View File

@@ -1,22 +0,0 @@
name: update docs
on:
push:
tags:
- 'v*'
paths:
- 'management/server/http/api/openapi.yml'
jobs:
trigger_docs_api_update:
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@v1
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'

File diff suppressed because it is too large Load Diff

View File

@@ -79,19 +79,19 @@ nfpms:
dependencies: dependencies:
- netbird - netbird
uploads: # uploads:
- name: debian # - name: debian
ids: # ids:
- netbird-ui-deb # - netbird-ui-deb
mode: archive # mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= # target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com # username: dev@wiretrustee.com
method: PUT # method: PUT
- name: yum # - name: yum
ids: # ids:
- netbird-ui-rpm # - netbird-ui-rpm
mode: archive # mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} # target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com # username: dev@wiretrustee.com
method: PUT # method: PUT

View File

@@ -1,6 +1,9 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,6 +1,7 @@
FROM alpine:3.21.0 FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \ RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird && adduser -D -h /var/lib/netbird netbird

View File

@@ -59,6 +59,8 @@ type Client struct {
deviceName string deviceName string
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -106,8 +108,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -132,8 +134,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -174,6 +176,53 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: routes[0].Network.String(),
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@@ -0,0 +1,27 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,30 +7,23 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum ConnStatus string // Todo replace to enum
} }
// PeerInfoCollection made for Java layer to get non default types as collection // PeerInfoArray is a wrapper of []PeerInfo
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct { type PeerInfoArray struct {
items []PeerInfo items []PeerInfo
} }
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray { func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
array.items = append(array.items, s) array.items = append(array.items, s)
return array return array
} }
// Get return an element of the collection // Get return an element of the collection
func (array PeerInfoArray) Get(i int) *PeerInfo { func (array *PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i] return &array.items[i]
} }
// Size return with the size of the collection // Size return with the size of the collection
func (array PeerInfoArray) Size() int { func (array *PeerInfoArray) Size() int {
return len(array.items) return len(array.items)
} }

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences export a subset of the internal config for gomobile // Preferences exports a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences create new Preferences instance // NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL read url from config file // GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL store the given url and wait for commit // SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL read url from config file // GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL store the given url and wait for commit // SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey read preshared key from config file // GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey store the given key and wait for commit // SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled store if rosenpass is enabled // SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) { func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled p.configInput.RosenpassEnabled = &enabled
} }
// GetRosenpassEnabled read rosenpass enabled from config file // GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) { func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil { if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil return *p.configInput.RosenpassEnabled, nil
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err return cfg.RosenpassEnabled, err
} }
// SetRosenpassPermissive store the given permissive and wait for commit // SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) { func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive p.configInput.RosenpassPermissive = &permissive
} }
// GetRosenpassPermissive read rosenpass permissive from config file // GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) { func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil { if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil return *p.configInput.RosenpassPermissive, nil
@@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err return cfg.RosenpassPermissive, err
} }
// Commit write out the changes into config file // GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@@ -69,7 +69,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+ cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+ " netbird up \n\n"+

View File

@@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.") "This overrides any policies received from the management service.")
} }

View File

@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
} }
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages { for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil { if stage.ForwardingDetails != nil {

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }

View File

@@ -3,6 +3,7 @@ package conntrack
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@@ -19,6 +20,10 @@ const (
DefaultICMPTimeout = 30 * time.Second DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -29,7 +34,7 @@ type ICMPConnKey struct {
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID) return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@@ -50,6 +55,72 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
@@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { func (t *ICMPTracker) TrackOutbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) { func (t *ICMPTracker) TrackInbound(
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size) srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) { func (t *ICMPTracker) track(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
} }
typ, code := typecode.Type(), typecode.Code() typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
} }
}) })
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
} }
b.ResetTimer() b.ResetTimer()

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
} }

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil { if errInToOut != nil {
if !isClosedError(errInToOut) { if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut) f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut)
} }
} }
if errOutToIn != nil { if errOutToIn != nil {
if !isClosedError(errOutToIn) { if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn) f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn)
} }
} }

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait() wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) { if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr) f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr)
} }
if inboundErr != nil && !isClosedError(inboundErr) { if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr) f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr)
} }
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64

View File

@@ -671,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
} }
@@ -684,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
} }

View File

@@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
if drop { r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{ protocols[r.Protocol] = &protoMatch{
ips: map[string]int{}, ips: map[string]int{},

View File

@@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
} }
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{ PeerConfig: &mgmProto.PeerConfig{

View File

@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {
@@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true updated = true
} else if config.ServerSSHAllowed == nil { } else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility if runtime.GOOS == "android" {
log.Infof("falling back to enabled SSH server for pre-existing configuration") // default to disabled SSH on Android for security
config.ServerSSHAllowed = util.True() log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true updated = true
} }

View File

@@ -175,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(), PeerConnID: conn.ConnID(),
Log: conn.Log, Log: conn.Log,
} }
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg) excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil { if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {

View File

@@ -11,9 +11,10 @@ import (
) )
const ( const (
PriorityDNSRoute = 100 PriorityLocal = 100
PriorityMatchDomain = 50 PriorityDNSRoute = 75
PriorityDefault = 1 PriorityUpstream = 50
PriorityDefault = 1
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "test.example.com.", queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "sub.test.example.com.", queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true, nbdns.PriorityUpstream: true,
}, },
}, },
{ {
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true, nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
}, },
}, },
{ {
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityDefault}, {"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true, nbdns.PriorityDefault: true,
}, },
}, },
} }
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
// Test 1: Initial state // Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool shouldMatch bool
}{ }{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false}, {"example.com.", nbdns.PriorityUpstream, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true},
}, },
query: "example.com.", query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "test.sub.example.com.", query: "test.sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true}, {"add", "example.com.", nbdns.PriorityDNSRoute, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",

View File

@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityLocal,
}) })
for _, record := range customZone.Records { for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups) groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS { for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain basePriority := PriorityUpstream
if domainGroup.domain == nbdns.RootZone { if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault basePriority = PriorityDefault
} }
@@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i priority := basePriority - i
// Check if we're about to overlap with the next priority tier // Check if we're about to overlap with the next priority tier.
if basePriority == PriorityMatchDomain && priority <= PriorityDefault { // This boundary check ensures that the priority of upstream handlers does not conflict
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault) domainGroup.domain, PriorityUpstream-PriorityDefault)
break break
} }

View File

@@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
dummyHandler.ID(): handlerWrapper{ dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
generateDummyHandler(".", nameServers).ID(): handlerWrapper{ generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone, domain: nbdns.RootZone,
@@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"local-resolver": handlerWrapper{ "local-resolver": handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
}, },
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
testCases := []struct { testCases := []struct {
name string name string
@@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
} }
@@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
"upstream-other": { "upstream-other": {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
@@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain + 1, priority: PriorityUpstream + 1,
}, },
// Keep existing groups with their original priorities // Keep existing groups with their original priorities
{ {
@@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
// Add group3 with lowest priority // Add group3 with lowest priority
{ {
@@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain - 2, priority: PriorityUpstream - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "new.com", domain: "new.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-new", Id: "upstream-new",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain // Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
// Verify refcount is 2 // Verify refcount is 2
zoneKey := toZone("shared.example.com") zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler // Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
// Verify refcount is 1 // Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
} }
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
} }
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -2,6 +2,7 @@ package dns
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error var err error
defer func() { defer func() {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
log.Tracef("%s has been stopped", u) logger.Tracef("%s has been stopped", u)
return return
default: default:
} }
@@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue continue
} }
if rm == nil || !rm.Response { if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
u.successCount.Add(1) u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil { if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure) m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil { if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
} }
} }
@@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -84,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
} }
return false return false
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -36,3 +36,10 @@ func newUpstreamResolver(
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -18,14 +18,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress string
ttl uint32 ttl uint32
@@ -38,16 +44,18 @@ type DNSForwarder struct {
mutex sync.RWMutex mutex sync.RWMutex
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewall.Manager firewall firewaller
resolver resolver
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
} }
} }
@@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server // UDP server
mux := dns.NewServeMux() mux := dns.NewServeMux()
f.mux = mux f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{ f.dnsServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
} }
// TCP server // TCP server
tcpMux := dns.NewServeMux() tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{ f.tcpServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "tcp", Net: "tcp",
@@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown) // return the first error we get (e.g. bind failure or shutdown)
return <-errCh return <-errCh
} }
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) log.Debugf("Updated DNS forwarder with %d domains", len(entries))
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(w, query, resp, domain, err)
return nil return nil
} }
f.updateInternalState(domain, ips) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
return resp return resp
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) resp := f.handleDNSQuery(w, query)
if resp == nil { if resp == nil {
return return
@@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
var prefixes []netip.Prefix var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" { if mostSpecificResId != "" {
for _, ip := range ips { for _, ip := range ips {
var prefix netip.Prefix var prefix netip.Prefix
@@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,11 +1,21 @@
package dnsfwd package dnsfwd
import ( import (
"context"
"fmt"
"net/netip"
"strings"
"testing" "testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -13,7 +23,7 @@ import (
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId storedMappings map[string]route.ResID
queryDomain string queryDomain string
expectedResId route.ResID expectedResId route.ResID
}{ }{
@@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
{ {
name: "Wildcard pattern does not match different domain", name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"}, storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com", queryDomain: "foo.example.org",
expectedResId: "", expectedResId: "",
}, },
{ {
@@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
}) })
} }
} }
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
} }
switch runtime.GOOS { switch runtime.GOOS {

View File

@@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() {
func (i *Monitor) ResetTimer() { func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold) i.timer.Reset(i.inactivityThreshold)
} }
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@@ -58,7 +58,7 @@ type Manager struct {
// Route HA group management // Route HA group management
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex // protects route mappings routesMu sync.RWMutex
onInactive chan peerid.ConnID onInactive chan peerid.ConnID
} }
@@ -146,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) {
case peerConnID := <-m.activityManager.OnActivityChan: case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID) m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive: case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID) m.onPeerInactivityTimedOut(ctx, peerConnID)
} }
} }
} }
@@ -197,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
return added return added
} }
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -225,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
peerCfg: &peerCfg, peerCfg: &peerCfg,
expectedWatcher: watcherActivity, expectedWatcher: watcherActivity,
} }
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(ctx, peerCfg)
}
return false, nil return false, nil
} }
@@ -315,36 +322,38 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to // activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
var peersToActivate []string
m.routesMu.RLock() m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID] haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
if len(haGroups) == 0 { if len(haGroups) == 0 {
m.routesMu.RUnlock()
log.Debugf("peer %s is not part of any HA groups", triggerPeerID) log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return return
} }
activatedCount := 0
for _, haGroup := range haGroups { for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup] peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
for _, peerID := range peers { for _, peerID := range peers {
if peerID == triggerPeerID { if peerID != triggerPeerID {
continue peersToActivate = append(peersToActivate, peerID)
} }
}
}
m.routesMu.RUnlock()
cfg, mp := m.getPeerForActivation(peerID) activatedCount := 0
if cfg == nil { for _, peerID := range peersToActivate {
continue cfg, mp := m.getPeerForActivation(peerID)
} if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) { if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++ activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID) cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
} }
} }
@@ -354,6 +363,51 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
} }
} }
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed") peerCfg.Log.Warnf("peer already managed")
@@ -415,6 +469,48 @@ func (m *Manager) close() {
log.Infof("lazy connection manager closed") log.Infof("lazy connection manager closed")
} }
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
}
}
return false
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -441,7 +537,7 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
} }
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) { func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -456,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
return return
} }
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
}
return
}
mp.peerCfg.Log.Infof("connection timed out") mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized // this is blocking operation, potentially can be optimized
@@ -489,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok { if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer") mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return return
} }

View File

@@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
eventStr = "Ended" eventStr = "Ended"
} }
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort) log.Tracef("%s %s %s connection: %s:%d %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
c.flowLogger.StoreEvent(nftypes.EventFields{ c.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID, FlowID: flowID,

View File

@@ -317,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig return conn.config.WgConfig
} }
// IsConnected unit tests only // IsConnected returns true if the peer is connected
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool { func (conn *Conn) IsConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
return conn.currentConnPriority != conntype.None
return conn.evalStatus() == StatusConnected
} }
func (conn *Conn) GetKey() string { func (conn *Conn) GetKey() string {

View File

@@ -575,13 +575,12 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
// FinishPeerListModifications this event invoke the notification // FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() { func (d *Status) FinishPeerListModifications() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification { if !d.peerListChangedForNotification {
d.mux.Unlock()
return return
} }
d.peerListChangedForNotification = false d.peerListChangedForNotification = false
d.mux.Unlock()
d.notifyPeerListChanged() d.notifyPeerListChanged()

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"reflect" "reflect"
"runtime"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -23,7 +22,7 @@ import (
const ( const (
handlerTypeDynamic = iota handlerTypeDynamic = iota
handlerTypeDomain handlerTypeDnsInterceptor
handlerTypeStatic handlerTypeStatic
) )
@@ -566,13 +565,14 @@ func HandlerFromRoute(
useNewDNSRoute bool, useNewDNSRoute bool,
) RouteHandler { ) RouteHandler {
switch handlerType(rt, useNewDNSRoute) { switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDomain: case handlerTypeDnsInterceptor:
return dnsinterceptor.New( return dnsinterceptor.New(
rt, rt,
routeRefCounter, routeRefCounter,
allowedIPsRefCounter, allowedIPsRefCounter,
statusRecorder, statusRecorder,
dnsServer, dnsServer,
wgInterface,
peerStore, peerStore,
) )
case handlerTypeDynamic: case handlerTypeDynamic:
@@ -596,8 +596,8 @@ func handlerType(rt *route.Route, useNewDNSRoute bool) int {
return handlerTypeStatic return handlerTypeStatic
} }
if useNewDNSRoute && runtime.GOOS != "ios" { if useNewDNSRoute {
return handlerTypeDomain return handlerTypeDnsInterceptor
} }
return handlerTypeDynamic return handlerTypeDynamic
} }

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -23,6 +24,11 @@ import (
type domainMap map[domain.Domain][]netip.Prefix type domainMap map[domain.Domain][]netip.Prefix
type wgInterface interface {
Name() string
Address() wgaddr.Address
}
type DnsInterceptor struct { type DnsInterceptor struct {
mu sync.RWMutex mu sync.RWMutex
route *route.Route route *route.Route
@@ -32,6 +38,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
interceptedDomains domainMap interceptedDomains domainMap
wgInterface wgInterface
peerStore *peerstore.Store peerStore *peerstore.Store
} }
@@ -41,6 +48,7 @@ func New(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status, statusRecorder *peer.Status,
dnsServer nbdns.Server, dnsServer nbdns.Server,
wgInterface wgInterface,
peerStore *peerstore.Store, peerStore *peerstore.Store,
) *DnsInterceptor { ) *DnsInterceptor {
return &DnsInterceptor{ return &DnsInterceptor{
@@ -49,6 +57,7 @@ func New(
allowedIPsRefcounter: allowedIPsRefCounter, allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dnsServer, dnsServer: dnsServer,
wgInterface: wgInterface,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
peerStore: peerStore, peerStore: peerStore,
} }
@@ -135,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface // ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
log.Tracef("received DNS request for domain=%s type=%v class=%v", logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query // pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query") d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return return
} }
@@ -152,29 +164,32 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock() d.mu.RUnlock()
if peerKey == "" { if peerKey == "" {
d.writeDNSError(w, r, "no current peer key") d.writeDNSError(w, r, logger, "no current peer key")
return return
} }
upstreamIP, err := d.getUpstreamIP(peerKey) upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil { if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return
}
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil {
d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return return
} }
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
client := &dns.Client{
Timeout: nbdns.UpstreamTimeout,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil { if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
return return
} }
@@ -184,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer answer = reply.Answer
} }
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil { if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
} }
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure) resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err) logger.Errorf("failed to write DNS error response: %v", err)
} }
} }
// continueToNextHandler signals the handler chain to try the next handler // continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError) resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue // Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err) logger.Errorf("failed writing DNS continue response: %v", err)
} }
} }

View File

@@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap GetClientRoutesFunc func() route.HAMap

View File

@@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() { if r.IsDynamic() {
continue continue
} }
@@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" { if runtime.GOOS != "android" {
return return
} }
newNets := make([]string, 0)
var newNets []string
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String()) newNets = append(newNets, r.Network.String())
} }
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.initialRouteRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0) newNets := make([]string, 0)
for _, prefix := range prefixes { for _, prefix := range prefixes {
@@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.routeRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }

View File

@@ -10,10 +10,11 @@ type StatusType string
const ( const (
StatusIdle StatusType = "Idle" StatusIdle StatusType = "Idle"
StatusConnecting StatusType = "Connecting" StatusConnecting StatusType = "Connecting"
StatusConnected StatusType = "Connected" StatusConnected StatusType = "Connected"
StatusNeedsLogin StatusType = "NeedsLogin" StatusNeedsLogin StatusType = "NeedsLogin"
StatusLoginFailed StatusType = "LoginFailed" StatusLoginFailed StatusType = "LoginFailed"
StatusSessionExpired StatusType = "SessionExpired"
) )
// CtxInitState setup context state into the context tree. // CtxInitState setup context state into the context tree.

View File

@@ -46,28 +46,40 @@
<ComponentRef Id="NetbirdFiles" /> <ComponentRef Id="NetbirdFiles" />
</ComponentGroup> </ComponentGroup>
<Property Id="cmd" Value="cmd.exe"/>
<CustomAction Id="KillDaemon" <CustomAction Id="KillDaemon"
ExeCommand='/c "taskkill /im netbird.exe"' BinaryRef="WixCA"
DllEntry="WixQuietExec64"
Execute="deferred" Execute="deferred"
Property="cmd"
Impersonate="no" Impersonate="no"
Return="ignore" Return="ignore"
/> />
<CustomAction Id="KillUI" <CustomAction Id="KillUI"
ExeCommand='/c "taskkill /im netbird-ui.exe"' BinaryRef="WixCA"
DllEntry="WixQuietExec64"
Execute="deferred" Execute="deferred"
Property="cmd"
Impersonate="no" Impersonate="no"
Return="ignore" Return="ignore"
/> />
<CustomAction Id="SetKillDaemonCommand"
Property="KillDaemon"
Value="taskkill /f /im netbird.exe"
Execute="immediate"
/>
<CustomAction Id="SetKillUICommand"
Property="KillUI"
Value="taskkill /f /im netbird-ui.exe"
Execute="immediate"
/>
<InstallExecuteSequence> <InstallExecuteSequence>
<!-- For Uninstallation --> <!-- For Uninstallation -->
<Custom Action="SetKillDaemonCommand" Before="KillDaemon" Condition="Installed"/>
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/> <Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/> <Custom Action="SetKillUICommand" After="KillDaemon" Condition="Installed"/>
<Custom Action="KillUI" After="SetKillUICommand" Condition="Installed"/>
</InstallExecuteSequence> </InstallExecuteSequence>
<!-- Icons --> <!-- Icons -->

View File

@@ -8,6 +8,7 @@ import (
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@@ -66,6 +67,7 @@ type Server struct {
lastProbe time.Time lastProbe time.Time
persistNetworkMap bool persistNetworkMap bool
isSessionActive atomic.Bool
} }
type oauthAuthFlow struct { type oauthAuthFlow struct {
@@ -567,9 +569,6 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo) tokenInfo, err := s.oauthAuthFlow.flow.WaitToken(waitCTX, flowInfo)
if err != nil { if err != nil {
if err == context.Canceled {
return nil, nil //nolint:nilnil
}
s.mutex.Lock() s.mutex.Lock()
s.oauthAuthFlow.expiresAt = time.Now() s.oauthAuthFlow.expiresAt = time.Now()
s.mutex.Unlock() s.mutex.Unlock()
@@ -640,6 +639,7 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
for { for {
select { select {
case <-runningChan: case <-runningChan:
s.isSessionActive.Store(true)
return &proto.UpResponse{}, nil return &proto.UpResponse{}, nil
case <-callerCtx.Done(): case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready") log.Debug("context done, stopping the wait for engine to become ready")
@@ -668,6 +668,7 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
log.Errorf("failed to shut down properly: %v", err) log.Errorf("failed to shut down properly: %v", err)
return nil, err return nil, err
} }
s.isSessionActive.Store(false)
state := internal.CtxGetState(s.rootCtx) state := internal.CtxGetState(s.rootCtx)
state.Set(internal.StatusIdle) state.Set(internal.StatusIdle)
@@ -694,6 +695,12 @@ func (s *Server) Status(
return nil, err return nil, err
} }
if status == internal.StatusNeedsLogin && s.isSessionActive.Load() {
log.Debug("status requested while session is active, returning SessionExpired")
status = internal.StatusSessionExpired
s.isSessionActive.Store(false)
}
statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()} statusResponse := proto.StatusResponse{Status: string(status), DaemonVersion: version.NetbirdVersion()}
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())

View File

@@ -59,16 +59,16 @@ type Info struct {
Environment Environment Environment Environment
Files []File // for posture checks Files []File // for posture checks
RosenpassEnabled bool RosenpassEnabled bool
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }

View File

@@ -20,7 +20,10 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/fyne/v2/app" "fyne.io/fyne/v2/app"
"fyne.io/fyne/v2/canvas"
"fyne.io/fyne/v2/container"
"fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/dialog"
"fyne.io/fyne/v2/layout"
"fyne.io/fyne/v2/theme" "fyne.io/fyne/v2/theme"
"fyne.io/fyne/v2/widget" "fyne.io/fyne/v2/widget"
"fyne.io/systray" "fyne.io/systray"
@@ -51,7 +54,7 @@ const (
) )
func main() { func main() {
daemonAddr, showSettings, showNetworks, showDebug, errorMsg, saveLogsInFile := parseFlags() daemonAddr, showSettings, showNetworks, showLoginURL, showDebug, errorMsg, saveLogsInFile := parseFlags()
// Initialize file logging if needed. // Initialize file logging if needed.
var logFile string var logFile string
@@ -77,13 +80,13 @@ func main() {
} }
// Create the service client (this also builds the settings or networks UI if requested). // Create the service client (this also builds the settings or networks UI if requested).
client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showDebug) client := newServiceClient(daemonAddr, logFile, a, showSettings, showNetworks, showLoginURL, showDebug)
// Watch for theme/settings changes to update the icon. // Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client) go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set. // Run in window mode if any UI flag was set.
if showSettings || showNetworks || showDebug { if showSettings || showNetworks || showDebug || showLoginURL {
a.Run() a.Run()
return return
} }
@@ -104,7 +107,7 @@ func main() {
} }
// parseFlags reads and returns all needed command-line flags. // parseFlags reads and returns all needed command-line flags.
func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool, errorMsg string, saveLogsInFile bool) { func parseFlags() (daemonAddr string, showSettings, showNetworks, showLoginURL, showDebug bool, errorMsg string, saveLogsInFile bool) {
defaultDaemonAddr := "unix:///var/run/netbird.sock" defaultDaemonAddr := "unix:///var/run/netbird.sock"
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
defaultDaemonAddr = "tcp://127.0.0.1:41731" defaultDaemonAddr = "tcp://127.0.0.1:41731"
@@ -112,6 +115,7 @@ func parseFlags() (daemonAddr string, showSettings, showNetworks, showDebug bool
flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]") flag.StringVar(&daemonAddr, "daemon-addr", defaultDaemonAddr, "Daemon service address to serve CLI requests [unix|tcp]://[path|host:port]")
flag.BoolVar(&showSettings, "settings", false, "run settings window") flag.BoolVar(&showSettings, "settings", false, "run settings window")
flag.BoolVar(&showNetworks, "networks", false, "run networks window") flag.BoolVar(&showNetworks, "networks", false, "run networks window")
flag.BoolVar(&showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&showDebug, "debug", false, "run debug window") flag.BoolVar(&showDebug, "debug", false, "run debug window")
flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window") flag.StringVar(&errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir())) flag.BoolVar(&saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
@@ -253,6 +257,7 @@ type serviceClient struct {
exitNodeStates []exitNodeState exitNodeStates []exitNodeState
mExitNodeDeselectAll *systray.MenuItem mExitNodeDeselectAll *systray.MenuItem
logFile string logFile string
wLoginURL fyne.Window
} }
type menuHandler struct { type menuHandler struct {
@@ -263,7 +268,7 @@ type menuHandler struct {
// newServiceClient instance constructor // newServiceClient instance constructor
// //
// This constructor also builds the UI elements for the settings window. // This constructor also builds the UI elements for the settings window.
func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showDebug bool) *serviceClient { func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool, showNetworks bool, showLoginURL bool, showDebug bool) *serviceClient {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &serviceClient{ s := &serviceClient{
ctx: ctx, ctx: ctx,
@@ -275,7 +280,7 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool
showAdvancedSettings: showSettings, showAdvancedSettings: showSettings,
showNetworks: showNetworks, showNetworks: showNetworks,
update: version.NewUpdate(), update: version.NewUpdate("nb/client-ui"),
} }
s.eventHandler = newEventHandler(s) s.eventHandler = newEventHandler(s)
@@ -286,6 +291,8 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool
s.showSettingsUI() s.showSettingsUI()
case showNetworks: case showNetworks:
s.showNetworksUI() s.showNetworksUI()
case showLoginURL:
s.showLoginURL()
case showDebug: case showDebug:
s.showDebugUI() s.showDebugUI()
} }
@@ -445,11 +452,11 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
} }
} }
func (s *serviceClient) login() error { func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
conn, err := s.getSrvClient(defaultFailTimeout) conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil { if err != nil {
log.Errorf("get client: %v", err) log.Errorf("get client: %v", err)
return err return nil, err
} }
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
@@ -457,24 +464,24 @@ func (s *serviceClient) login() error {
}) })
if err != nil { if err != nil {
log.Errorf("login to management URL with: %v", err) log.Errorf("login to management URL with: %v", err)
return err return nil, err
} }
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin && openURL {
err = open.Run(loginResp.VerificationURIComplete) err = open.Run(loginResp.VerificationURIComplete)
if err != nil { if err != nil {
log.Errorf("opening the verification uri in the browser failed: %v", err) log.Errorf("opening the verification uri in the browser failed: %v", err)
return err return nil, err
} }
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
if err != nil { if err != nil {
log.Errorf("waiting sso login failed with: %v", err) log.Errorf("waiting sso login failed with: %v", err)
return err return nil, err
} }
} }
return nil return loginResp, nil
} }
func (s *serviceClient) menuUpClick() error { func (s *serviceClient) menuUpClick() error {
@@ -486,7 +493,7 @@ func (s *serviceClient) menuUpClick() error {
return err return err
} }
err = s.login() _, err = s.login(true)
if err != nil { if err != nil {
log.Errorf("login failed with: %v", err) log.Errorf("login failed with: %v", err)
return err return err
@@ -558,7 +565,7 @@ func (s *serviceClient) updateStatus() error {
defer s.updateIndicationLock.Unlock() defer s.updateIndicationLock.Unlock()
// notify the user when the session has expired // notify the user when the session has expired
if status.Status == string(internal.StatusNeedsLogin) { if status.Status == string(internal.StatusSessionExpired) {
s.onSessionExpire() s.onSessionExpire()
} }
@@ -732,7 +739,6 @@ func (s *serviceClient) onTrayReady() {
go s.eventHandler.listen(s.ctx) go s.eventHandler.listen(s.ctx)
} }
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File { func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
if s.logFile == "" { if s.logFile == "" {
// attach child's streams to parent's streams // attach child's streams to parent's streams
@@ -871,17 +877,9 @@ func (s *serviceClient) onUpdateAvailable() {
// onSessionExpire sends a notification to the user when the session expires. // onSessionExpire sends a notification to the user when the session expires.
func (s *serviceClient) onSessionExpire() { func (s *serviceClient) onSessionExpire() {
s.sendNotification = true
if s.sendNotification { if s.sendNotification {
title := "Connection session expired" s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
if runtime.GOOS == "darwin" {
title = "NetBird connection session expired"
}
s.app.SendNotification(
fyne.NewNotification(
title,
"Please re-authenticate to connect to the network",
),
)
s.sendNotification = false s.sendNotification = false
} }
} }
@@ -955,9 +953,9 @@ func (s *serviceClient) updateConfig() error {
ServerSSHAllowed: &sshAllowed, ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled, RosenpassEnabled: &rosenpassEnabled,
DisableAutoConnect: &disableAutoStart, DisableAutoConnect: &disableAutoStart,
DisableNotifications: &notificationsDisabled,
LazyConnectionEnabled: &lazyConnectionEnabled, LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound, BlockInbound: &blockInbound,
DisableNotifications: &notificationsDisabled,
} }
if err := s.restartClient(&loginRequest); err != nil { if err := s.restartClient(&loginRequest); err != nil {
@@ -991,6 +989,99 @@ func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
return nil return nil
} }
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
func (s *serviceClient) showLoginURL() {
resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
if s.wLoginURL == nil {
s.wLoginURL = s.app.NewWindow("NetBird Session Expired")
s.wLoginURL.Resize(fyne.NewSize(400, 200))
s.wLoginURL.SetIcon(resIcon)
}
// add a description label
label := widget.NewLabel("Your NetBird session has expired.\nPlease re-authenticate to continue using NetBird.")
btn := widget.NewButtonWithIcon("Re-authenticate", theme.ViewRefreshIcon(), func() {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
log.Errorf("get client: %v", err)
return
}
resp, err := s.login(false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
}
verificationURL := resp.VerificationURIComplete
if verificationURL == "" {
verificationURL = resp.VerificationURI
}
if verificationURL == "" {
log.Error("no verification URL provided in the login response")
return
}
if err := openURL(verificationURL); err != nil {
log.Errorf("failed to open login URL: %v", err)
return
}
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
if err != nil {
log.Errorf("Waiting sso login failed with: %v", err)
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
return
}
label.SetText("Re-authentication successful.\nReconnecting")
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return
}
if status.Status == string(internal.StatusConnected) {
label.SetText("Already connected.\nClosing this window.")
time.Sleep(2 * time.Second)
s.wLoginURL.Close()
return
}
_, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
log.Errorf("Reconnecting failed with: %v", err)
return
}
label.SetText("Connection successful.\nClosing this window.")
time.Sleep(time.Second)
s.wLoginURL.Close()
})
img := canvas.NewImageFromResource(resIcon)
img.FillMode = canvas.ImageFillContain
img.SetMinSize(fyne.NewSize(64, 64))
img.Resize(fyne.NewSize(64, 64))
// center the content vertically
content := container.NewVBox(
layout.NewSpacer(),
img,
label,
btn,
layout.NewSpacer(),
)
s.wLoginURL.SetContent(container.NewCenter(content))
s.wLoginURL.Show()
}
func openURL(url string) error { func openURL(url string) error {
var err error var err error
switch runtime.GOOS { switch runtime.GOOS {

View File

@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/systray" "fyne.io/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
) )
type eventHandler struct { type eventHandler struct {
@@ -120,7 +122,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() { go func() {
defer h.client.mAdvancedSettings.Enable() defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig() defer h.client.getSrvConfig()
h.runSelfCommand("settings", "true") h.runSelfCommand(h.client.ctx, "settings", "true")
}() }()
} }
@@ -128,7 +130,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable() h.client.mCreateDebugBundle.Disable()
go func() { go func() {
defer h.client.mCreateDebugBundle.Enable() defer h.client.mCreateDebugBundle.Enable()
h.runSelfCommand("debug", "true") h.runSelfCommand(h.client.ctx, "debug", "true")
}() }()
} }
@@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() {
} }
func (h *eventHandler) handleUpdateClick() { func (h *eventHandler) handleUpdateClick() {
if err := openURL("https://netbird.io/download"); err != nil { if err := openURL(version.DownloadUrl()); err != nil {
log.Errorf("failed to open download URL: %v", err) log.Errorf("failed to open download URL: %v", err)
} }
} }
@@ -152,7 +154,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable() h.client.mNetworks.Disable()
go func() { go func() {
defer h.client.mNetworks.Enable() defer h.client.mNetworks.Enable()
h.runSelfCommand("networks", "true") h.runSelfCommand(h.client.ctx, "networks", "true")
}() }()
} }
@@ -170,14 +172,14 @@ func (h *eventHandler) updateConfigWithErr() {
} }
} }
func (h *eventHandler) runSelfCommand(command, arg string) { func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
proc, err := os.Executable() proc, err := os.Executable()
if err != nil { if err != nil {
log.Errorf("error getting executable path: %v", err) log.Errorf("error getting executable path: %v", err)
return return
} }
cmd := exec.Command(proc, cmd := exec.CommandContext(ctx, proc,
fmt.Sprintf("--%s=%s", command, arg), fmt.Sprintf("--%s=%s", command, arg),
fmt.Sprintf("--daemon-addr=%s", h.client.addr), fmt.Sprintf("--daemon-addr=%s", h.client.addr),
) )

View File

@@ -358,8 +358,6 @@ func (s *serviceClient) updateExitNodes() {
} else { } else {
s.mExitNode.Disable() s.mExitNode.Disable()
} }
log.Debugf("Exit nodes updated: %d", len(s.mExitNodeItems))
} }
func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) { func (s *serviceClient) recreateExitNodeMenu(exitNodes []*proto.Network) {

2
go.mod
View File

@@ -63,7 +63,7 @@ require (
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/nadoo/ipset v0.5.0 github.com/nadoo/ipset v0.5.0
github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb
github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/oschwald/maxminddb-golang v1.12.0 github.com/oschwald/maxminddb-golang v1.12.0

4
go.sum
View File

@@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c=
github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c h1:SdZxYjR9XXHLyRsTbS1EHBr6+RI15oie1K9Q8yvi3FY= github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65 h1:5OfYiLjpr4dbQYJI5ouZaylkVdi2KlErLFOwBeBo5Hw=
github.com/netbirdio/management-integrations/integrations v0.0.0-20250529122842-6700aa91190c/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20250612164546-6bd7e2338d65/go.mod h1:Gi9raplYzCCyh07Olw/DVfCJTFgpr1WCXJ/Q+8TSA9Q=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8=
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250514131221-a464fd5f30cb h1:Cr6age+ePALqlSvtp7wc6lYY97XN7rkD1K4XEDmY+TU=

View File

@@ -159,6 +159,12 @@ var (
if err != nil { if err != nil {
return err return err
} }
integrationMetrics, err := integrations.InitIntegrationMetrics(ctx, appMetrics)
if err != nil {
return err
}
store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false) store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics, false)
if err != nil { if err != nil {
return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err)
@@ -176,7 +182,7 @@ var (
if disableSingleAccMode { if disableSingleAccMode {
mgmtSingleAccModeDomain = "" mgmtSingleAccModeDomain = ""
} }
eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey) eventStore, key, err := integrations.InitEventStore(ctx, config.Datadir, config.DataStoreEncryptionKey, integrationMetrics)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize database: %s", err) return fmt.Errorf("failed to initialize database: %s", err)
} }
@@ -351,6 +357,13 @@ var (
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled) serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
update := version.NewUpdate("nb/management")
update.SetDaemonVersion(version.NetbirdVersion())
update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
defer update.StopWatch()
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh

View File

@@ -24,6 +24,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -409,14 +410,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.schedulePeerLoginExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
} }
} }
@@ -454,6 +456,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
//nolint
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@@ -478,8 +484,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
} }
} }
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) if am.peerLoginExpiry.IsSchedulerRunning(accountID) {
log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID)
return
}
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
} }
@@ -1844,40 +1853,49 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
} }
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
account, err := am.Store.GetAccount(ctx, accountId) var account *types.Account
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
account, err = transaction.GetAccount(ctx, accountId)
if err != nil {
return err
}
if account.IsDomainPrimaryAccount {
return nil
}
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := transaction.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return status.Errorf(status.Internal, "failed to update account to primary")
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if account.IsDomainPrimaryAccount {
return account, nil
}
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return nil, err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return nil, status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return nil, status.Errorf(status.Internal, "failed to update account to primary")
}
return account, nil return account, nil
} }

View File

@@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(1)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done() wg.Done()
}, },

View File

@@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)

View File

@@ -426,6 +426,10 @@ components:
items: items:
type: string type: string
example: "stage-host-1" example: "stage-host-1"
ephemeral:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
required: required:
- city_name - city_name
- connected - connected
@@ -450,6 +454,7 @@ components:
- approval_required - approval_required
- serial_number - serial_number
- extra_dns_labels - extra_dns_labels
- ephemeral
AccessiblePeer: AccessiblePeer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'

View File

@@ -1016,6 +1016,9 @@ type Peer struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer // ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"` ExtraDnsLabels []string `json:"extra_dns_labels"`
@@ -1097,6 +1100,9 @@ type PeerBatch struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer // ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"` ExtraDnsLabels []string `json:"extra_dns_labels"`

View File

@@ -365,6 +365,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
} }
} }

View File

@@ -37,21 +37,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID) return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err != nil { a, err := transaction.GetAccountByUser(ctx, userID)
return err if err != nil {
} return err
}
var extra *types.ExtraSettings var extra *types.ExtraSettings
if a.Settings.Extra != nil { if a.Settings.Extra != nil {
extra = a.Settings.Extra extra = a.Settings.Extra
} else { } else {
extra = &types.ExtraSettings{} extra = &types.ExtraSettings{}
a.Settings.Extra = extra a.Settings.Extra = extra
} }
extra.IntegratedValidatorGroups = groups extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(ctx, a) return transaction.SaveAccount(ctx, a)
})
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {

View File

@@ -184,7 +184,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
ephemeralPeersSKs int ephemeralPeersSKs int
ephemeralPeersSKUsage int ephemeralPeersSKUsage int
activePeersLastDay int activePeersLastDay int
activeUserPeersLastDay int
osPeers map[string]int osPeers map[string]int
activeUsersLastDay map[string]struct{}
userPeers int userPeers int
rules int rules int
rulesProtocol map[string]int rulesProtocol map[string]int
@@ -203,6 +205,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
version string version string
peerActiveVersions []string peerActiveVersions []string
osUIClients map[string]int osUIClients map[string]int
rosenpassEnabled int
) )
start := time.Now() start := time.Now()
metricsProperties := make(properties) metricsProperties := make(properties)
@@ -210,6 +213,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
osUIClients = make(map[string]int) osUIClients = make(map[string]int)
rulesProtocol = make(map[string]int) rulesProtocol = make(map[string]int)
rulesDirection = make(map[string]int) rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
uptime = time.Since(w.startupTime).Seconds() uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers() connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion() version = nbversion.NetbirdVersion()
@@ -277,10 +281,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
for _, peer := range account.Peers { for _, peer := range account.Peers {
peers++ peers++
if peer.SSHEnabled { if peer.SSHEnabled || peer.Meta.Flags.ServerSSHAllowed {
peersSSHEnabled++ peersSSHEnabled++
} }
if peer.Meta.Flags.RosenpassEnabled {
rosenpassEnabled++
}
if peer.UserID != "" { if peer.UserID != "" {
userPeers++ userPeers++
} }
@@ -299,6 +307,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
_, connected := connections[peer.ID] _, connected := connections[peer.ID]
if connected || peer.Status.LastSeen.After(w.lastRun) { if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++ activePeersLastDay++
if peer.UserID != "" {
activeUserPeersLastDay++
activeUsersLastDay[peer.UserID] = struct{}{}
}
osActiveKey := osKey + "_active" osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey] osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1 osPeers[osActiveKey] = osActiveCount + 1
@@ -320,6 +332,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs
metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage
metricsProperties["active_peers_last_day"] = activePeersLastDay metricsProperties["active_peers_last_day"] = activePeersLastDay
metricsProperties["active_user_peers_last_day"] = activeUserPeersLastDay
metricsProperties["active_users_last_day"] = len(activeUsersLastDay)
metricsProperties["user_peers"] = userPeers metricsProperties["user_peers"] = userPeers
metricsProperties["rules"] = rules metricsProperties["rules"] = rules
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
@@ -338,6 +352,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["ui_clients"] = uiClient metricsProperties["ui_clients"] = uiClient
metricsProperties["idp_manager"] = w.idpManager metricsProperties["idp_manager"] = w.idpManager
metricsProperties["store_engine"] = w.dataSource.GetStoreEngine() metricsProperties["store_engine"] = w.dataSource.GetStoreEngine()
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
for protocol, count := range rulesProtocol { for protocol, count := range rulesProtocol {
metricsProperties["rules_protocol_"+protocol] = count metricsProperties["rules_protocol_"+protocol] = count

View File

@@ -47,8 +47,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": { "1": {
ID: "1", ID: "1",
UserID: "test", UserID: "test",
SSHEnabled: true, SSHEnabled: false,
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1", Flags: nbpeer.Flags{ServerSSHAllowed: true, RosenpassEnabled: true}},
}, },
}, },
Policies: []*types.Policy{ Policies: []*types.Policy{
@@ -312,7 +312,19 @@ func TestGenerateProperties(t *testing.T) {
} }
if properties["posture_checks"] != 2 { if properties["posture_checks"] != 2 {
t.Errorf("expected 1 posture_checks, got %d", properties["posture_checks"]) t.Errorf("expected 2 posture_checks, got %d", properties["posture_checks"])
}
if properties["rosenpass_enabled"] != 1 {
t.Errorf("expected 1 rosenpass_enabled, got %d", properties["rosenpass_enabled"])
}
if properties["active_user_peers_last_day"] != 2 {
t.Errorf("expected 2 active_user_peers_last_day, got %d", properties["active_user_peers_last_day"])
}
if properties["active_users_last_day"] != 1 {
t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"])
} }
} }

View File

@@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.schedulePeerLoginExpiration(ctx, accountID)
} }
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
@@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
} }
} }
@@ -1148,7 +1149,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID { if aclPeer.ID == peer.ID {
return peer, nil return peer, nil

View File

@@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
ID: "peerB", ID: "peerB",
IP: net.ParseIP("100.65.80.39"), IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
}, },
"peerC": { "peerC": {
ID: "peerC", ID: "peerC",
@@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
IP: net.ParseIP("100.65.31.2"), IP: net.ParseIP("100.65.31.2"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
"peerK": {
ID: "peerK",
IP: net.ParseIP("100.32.80.1"),
Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
},
}, },
Groups: map[string]*types.Group{ Groups: map[string]*types.Group{
"GroupAll": { "GroupAll": {
@@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerI", "peerI",
}, },
}, },
"GroupWorkflow": {
ID: "GroupWorkflow",
Name: "workflow",
Peers: []string{
"peerK",
},
},
}, },
Policies: []*types.Policy{ Policies: []*types.Policy{
{ {
@@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
}, },
}, },
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
PortRanges: []types.RulePortRange{
{
Start: 8088,
End: 8088,
},
{
Start: 9090,
End: 9095,
},
},
Sources: []string{
"GroupWorkflow",
},
Destinations: []string{
"GroupDMZ",
},
},
},
},
}, },
} }
@@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
assert.Len(t, peers, 8) assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.True(t, contains, "rule not found in expected rules %#v", rule) assert.True(t, contains, "rule not found in expected rules %#v", rule)
} }
}) })
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
}
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
})
} }
func TestAccount_getPeersByPolicyDirect(t *testing.T) { func TestAccount_getPeersByPolicyDirect(t *testing.T) {
@@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.254.139", PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.80.39", PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.254.139", PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT, Direction: types.FirewallRuleDirectionOUT,
@@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.80.39", PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
} }
@@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
} }
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
peerVersion := sanitizeVersion(peer.Meta.WtVersion) meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
minVersion := sanitizeVersion(n.MinVersion)
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil { if err != nil {
return false, err return false, err
} }
constraints, err := version.NewConstraint(">= " + minVersion) if meetsMin {
if err != nil {
return false, err
}
if constraints.Check(peerNBVersion) {
return true, nil return true, nil
} }
@@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
} }
return nil return nil
} }
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
peerVer = sanitizeVersion(peerVer)
minVer = sanitizeVersion(minVer)
peerNBVer, err := version.NewVersion(peerVer)
if err != nil {
return false, err
}
constraints, err := version.NewConstraint(">= " + minVer)
if err != nil {
return false, err
}
return constraints.Check(peerNBVer), nil
}

View File

@@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
}) })
} }
} }
func TestMeetsMinVersion(t *testing.T) {
tests := []struct {
name string
minVer string
peerVer string
want bool
wantErr bool
}{
{
name: "Peer version greater than min version",
minVer: "0.26.0",
peerVer: "0.60.1",
want: true,
wantErr: false,
},
{
name: "Peer version equals min version",
minVer: "1.0.0",
peerVer: "1.0.0",
want: true,
wantErr: false,
},
{
name: "Peer version less than min version",
minVer: "1.0.0",
peerVer: "0.9.9",
want: false,
wantErr: false,
},
{
name: "Peer version with pre-release tag greater than min version",
minVer: "1.0.0",
peerVer: "1.0.1-alpha",
want: true,
wantErr: false,
},
{
name: "Invalid peer version format",
minVer: "1.0.0",
peerVer: "dev",
want: false,
wantErr: true,
},
{
name: "Invalid min version format",
minVer: "invalid.version",
peerVer: "1.0.0",
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -4,19 +4,19 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"unicode/utf8" "unicode/utf8"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) prefix := checkRoute.Network
domains := checkRoute.Domains
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route // we skip route(s) with the same network ID as we want to allow updating of the existing route
// when creating a new route routeID is newly generated so nothing will be skipped // when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID { if checkRoute.ID == prefixRoute.ID {
continue continue
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, ok := peerGroupsMap[groupID]
if group == nil { if !ok || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
} }
if peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range checkRoute.PeerGroups {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
@@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
if err != nil {
return err
}
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, ok := peersMap[id]
if peer == nil { if !ok || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
if len(domains) > 0 && prefix.IsValid() { if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
if len(domains) == 0 && !prefix.IsValid() { var newRoute *route.Route
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") var updateAccountPeers bool
}
if len(domains) > 0 { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
prefix = getPlaceholderIP() newRoute = &route.Route{
} ID: route.ID(xid.New().String()),
AccountID: accountID,
if peerID != "" && len(peerGroupIDs) != 0 { Network: prefix,
return nil, status.Errorf( Domains: domains,
status.InvalidArgument, KeepRoute: keepRoute,
"peer with ID %s and peers group %s should not be provided at the same time", NetID: netID,
peerID, peerGroupIDs) Description: description,
} Peer: peerID,
PeerGroups: peerGroupIDs,
var newRoute route.Route NetworkType: networkType,
newRoute.ID = route.ID(xid.New().String()) Masquerade: masquerade,
Metric: metric,
if len(peerGroupIDs) > 0 { Enabled: enabled,
err = validateGroups(peerGroupIDs, account.Groups) Groups: groups,
if err != nil { AccessControlGroups: accessControlGroupIDs,
return nil, err
} }
}
if len(accessControlGroupIDs) > 0 { if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
err = validateGroups(accessControlGroupIDs, account.Groups) return err
if err != nil {
return nil, err
} }
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if am.isRouteChangeAffectPeers(account, &newRoute) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
@@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
return err
}
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
if err != nil {
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
}
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
} }
@@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
if err != nil {
return err
}
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
}
// validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
if err != nil {
return nil, err
}
if len(routeToSave.PeerGroups) > 0 { if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups) if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
if len(routeToSave.AccessControlGroups) > 0 { if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups) if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) return groupsMap, nil
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if am.isRouteChangeAffectPeers(account, routy) {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
@@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers.
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }

View File

@@ -12,6 +12,7 @@ import (
type Scheduler interface { type Scheduler interface {
Cancel(ctx context.Context, IDs []string) Cancel(ctx context.Context, IDs []string)
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
IsSchedulerRunning(ID string) bool
} }
// MockScheduler is a mock implementation of Scheduler // MockScheduler is a mock implementation of Scheduler
@@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) {
mock.CancelFunc(ctx, IDs) mock.CancelFunc(ctx, IDs)
return return
} }
log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ")
} }
// Schedule mocks the Schedule function of the Scheduler interface // Schedule mocks the Schedule function of the Scheduler interface
@@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st
mock.ScheduleFunc(ctx, in, ID, job) mock.ScheduleFunc(ctx, in, ID, job)
return return
} }
log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined")
}
func (mock *MockScheduler) IsSchedulerRunning(ID string) bool {
// MockScheduler does not implement IsSchedulerRunning, so we return false
log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined")
return false
} }
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
@@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s
}() }()
} }
// IsSchedulerRunning checks if a job with the provided ID is scheduled to run
func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool {
wm.mu.Lock()
defer wm.mu.Unlock()
_, ok := wm.jobs[ID]
return ok
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -182,7 +181,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
} }
assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes,
tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), tCase.expectedCreatedAt, tCase.expectedExpiresAt, key.Id,
tCase.expectedUpdatedAt, tCase.expectedGroups, false) tCase.expectedUpdatedAt, tCase.expectedGroups, false)
// check the corresponding events that should have been generated // check the corresponding events that should have been generated
@@ -258,10 +257,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) {
expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour)
var expectedAutoGroups []string var expectedAutoGroups []string
key, plainKey := types.GenerateDefaultSetupKey() key, _ := types.GenerateDefaultSetupKey()
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) expectedExpiresAt, key.Id, expectedUpdatedAt, expectedAutoGroups, true)
} }
@@ -275,10 +274,10 @@ func TestGenerateSetupKey(t *testing.T) {
expectedUpdatedAt := time.Now().UTC() expectedUpdatedAt := time.Now().UTC()
var expectedAutoGroups []string var expectedAutoGroups []string
key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false, false) key, _ := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false, false)
assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt,
expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) expectedExpiresAt, key.Id, expectedUpdatedAt, expectedAutoGroups, true)
} }

View File

@@ -227,3 +227,7 @@ func NewUserRoleNotFoundError(role string) error {
func NewOperationNotFoundError(operation operations.Operation) error { func NewOperationNotFoundError(operation operations.Operation) error {
return Errorf(NotFound, "operation: %s not found", operation) return Errorf(NotFound, "operation: %s not found", operation)
} }
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}

View File

@@ -23,8 +23,6 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -34,6 +32,7 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db, lockStrength, accountID) var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID)
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.NewRouteNotFoundError(routeID)
}
return nil
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
@@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil return nil
} }
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record []T
result := tx.Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record T
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}
// SaveDNSSettings saves the DNS settings to the store. // SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).

View File

@@ -19,21 +19,17 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
) )
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
@@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 8003, len(accountGroups)) require.Equal(t, 8003, len(accountGroups))
} }
func TestSqlStore_GetAccountRoutes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve routes by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, routes, tt.expectedCount)
})
}
}
func TestSqlStore_GetRouteByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
routeID string
expectError bool
}{
{
name: "retrieve existing route",
routeID: "ct03t427qv97vmtmglog",
expectError: false,
},
{
name: "retrieve non-existing route",
routeID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty route ID",
routeID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, route)
} else {
require.NoError(t, err)
require.NotNil(t, route)
require.Equal(t, tt.routeID, string(route.ID))
}
})
}
}
func TestSqlStore_SaveRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
route := &route2.Route{
ID: "route-id",
AccountID: accountID,
Network: netip.MustParsePrefix("10.10.0.0/16"),
NetID: "netID",
PeerGroups: []string{"routeA"},
NetworkType: route2.IPv4Network,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"groupA"},
AccessControlGroups: []string{},
}
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
require.NoError(t, err)
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
require.NoError(t, err)
require.Equal(t, route, saveRoute)
}
func TestSqlStore_DeleteRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
routeID := "ct03t427qv97vmtmglog"
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
require.NoError(t, err)
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
require.Error(t, err)
require.Nil(t, route)
}
func TestSqlStore_GetAccountMeta(t *testing.T) { func TestSqlStore_GetAccountMeta(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@@ -145,7 +145,9 @@ type Store interface {
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)

View File

@@ -184,10 +184,10 @@ func (appMetrics *defaultAppMetrics) Expose(ctx context.Context, port int, endpo
} }
appMetrics.listener = listener appMetrics.listener = listener
go func() { go func() {
err := http.Serve(listener, rootRouter) if err := http.Serve(listener, rootRouter); err != nil && err != http.ErrServerClosed {
if err != nil { log.WithContext(ctx).Errorf("metrics server error: %v", err)
return
} }
log.WithContext(ctx).Info("metrics server stopped")
}() }()
log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String()) log.WithContext(ctx).Infof("enabled application metrics and exposing on http://%s", listener.Addr().String())
@@ -204,7 +204,7 @@ func (appMetrics *defaultAppMetrics) GetMeter() metric2.Meter {
func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) { func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
exporter, err := prometheus.New() exporter, err := prometheus.New()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to create prometheus exporter: %w", err)
} }
provider := metric.NewMeterProvider(metric.WithReader(exporter)) provider := metric.NewMeterProvider(metric.WithReader(exporter))
@@ -213,32 +213,32 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
idpMetrics, err := NewIDPMetrics(ctx, meter) idpMetrics, err := NewIDPMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize IDP metrics: %w", err)
} }
middleware, err := NewMetricsMiddleware(ctx, meter) middleware, err := NewMetricsMiddleware(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize HTTP middleware metrics: %w", err)
} }
grpcMetrics, err := NewGRPCMetrics(ctx, meter) grpcMetrics, err := NewGRPCMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize gRPC metrics: %w", err)
} }
storeMetrics, err := NewStoreMetrics(ctx, meter) storeMetrics, err := NewStoreMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize store metrics: %w", err)
} }
updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter) updateChannelMetrics, err := NewUpdateChannelMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize update channel metrics: %w", err)
} }
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter) accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
} }
return &defaultAppMetrics{ return &defaultAppMetrics{

View File

@@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');

View File

@@ -36,6 +36,9 @@ const (
PublicCategory = "public" PublicCategory = "public"
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
firewallRuleMinPortRangesVer = "0.48.0"
) )
type LookupMap map[string]struct{} type LookupMap map[string]struct{}
@@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer // GetPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
continue continue
@@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
continue continue
} }
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be // It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls. // generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
@@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
continue continue
} }
for _, port := range rule.Ports { rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
for _, portRange := range rule.PortRanges {
pr := fr
pr.PortRange = portRange
rules = append(rules, &pr)
}
} }
}, func() ([]*nbpeer.Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules return peers, rules
@@ -1590,3 +1584,45 @@ func (a *Account) AddAllGroup() error {
} }
return nil return nil
} }
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
var expanded []*FirewallRule
if len(rule.Ports) > 0 {
for _, port := range rule.Ports {
fr := base
fr.Port = port
expanded = append(expanded, &fr)
}
return expanded
}
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
for _, portRange := range rule.PortRanges {
fr := base
if supportPortRanges {
fr.PortRange = portRange
} else {
// Peer doesn't support port ranges, only allow single-port ranges
if portRange.Start != portRange.End {
continue
}
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
}
expanded = append(expanded, &fr)
}
return expanded
}
// peerSupportsPortRanges checks if the peer version supports port ranges.
func peerSupportsPortRanges(peerVer string) bool {
if strings.Contains(peerVer, "dev") {
return true
}
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}

View File

@@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else { } else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
} }
// TODO: generate IPv6 rules for dynamic routes // TODO: generate IPv6 rules for dynamic routes

View File

@@ -3,13 +3,12 @@ package types
import ( import (
"crypto/sha256" "crypto/sha256"
b64 "encoding/base64" b64 "encoding/base64"
"hash/fnv"
"strconv"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
) )
@@ -170,7 +169,7 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoG
encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:])
return &SetupKey{ return &SetupKey{
Id: strconv.Itoa(int(Hash(key))), Id: xid.New().String(),
Key: encodedHashedKey, Key: encodedHashedKey,
KeySecret: HiddenKey(key, 4), KeySecret: HiddenKey(key, 4),
Name: name, Name: name,
@@ -192,12 +191,3 @@ func GenerateDefaultSetupKey() (*SetupKey, string) {
return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{},
SetupKeyUnlimitedUsage, false, false) SetupKeyUnlimitedUsage, false, false)
} }
func Hash(s string) uint32 {
h := fnv.New32a()
_, err := h.Write([]byte(s))
if err != nil {
panic(err)
}
return h.Sum32()
}

35
signal/cmd/env.go Normal file
View File

@@ -0,0 +1,35 @@
package cmd
import (
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_
func setFlagsFromEnvVars(cmd *cobra.Command) {
flags := cmd.PersistentFlags()
flags.VisitAll(func(f *pflag.Flag) {
newEnvVar := flagNameToEnvVar(f.Name, "NB_")
value, present := os.LookupEnv(newEnvVar)
if !present {
return
}
err := flags.Set(f.Name, value)
if err != nil {
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
}
})
}
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
// replacing dashes and making all uppercase (e.g. setup-keys is converted to NB_SETUP_KEYS according to the input prefix)
func flagNameToEnvVar(cmdFlag string, prefix string) string {
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
upper := strings.ToUpper(parsed)
return prefix + upper
}

View File

@@ -303,4 +303,5 @@ func init() {
runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS")
runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") runCmd.Flags().StringVar(&signalCertKey, "cert-key", "", "Location of your SSL certificate private key. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect")
setFlagsFromEnvVars(runCmd)
} }

View File

@@ -21,6 +21,7 @@ var (
// Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the // Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the
// daemon version are deprecated // daemon version are deprecated
type Update struct { type Update struct {
httpAgent string
uiVersion *goversion.Version uiVersion *goversion.Version
daemonVersion *goversion.Version daemonVersion *goversion.Version
latestAvailable *goversion.Version latestAvailable *goversion.Version
@@ -34,7 +35,7 @@ type Update struct {
} }
// NewUpdate instantiate Update and start to fetch the new version information // NewUpdate instantiate Update and start to fetch the new version information
func NewUpdate() *Update { func NewUpdate(httpAgent string) *Update {
currentVersion, err := goversion.NewVersion(version) currentVersion, err := goversion.NewVersion(version)
if err != nil { if err != nil {
currentVersion, _ = goversion.NewVersion("0.0.0") currentVersion, _ = goversion.NewVersion("0.0.0")
@@ -43,6 +44,7 @@ func NewUpdate() *Update {
latestAvailable, _ := goversion.NewVersion("0.0.0") latestAvailable, _ := goversion.NewVersion("0.0.0")
u := &Update{ u := &Update{
httpAgent: httpAgent,
latestAvailable: latestAvailable, latestAvailable: latestAvailable,
uiVersion: currentVersion, uiVersion: currentVersion,
fetchTicker: time.NewTicker(fetchPeriod), fetchTicker: time.NewTicker(fetchPeriod),
@@ -112,7 +114,15 @@ func (u *Update) startFetcher() {
func (u *Update) fetchVersion() bool { func (u *Update) fetchVersion() bool {
log.Debugf("fetching version info from %s", versionURL) log.Debugf("fetching version info from %s", versionURL)
resp, err := http.Get(versionURL) req, err := http.NewRequest("GET", versionURL, nil)
if err != nil {
log.Errorf("failed to create request for version info: %s", err)
return false
}
req.Header.Set("User-Agent", u.httpAgent)
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
log.Errorf("failed to fetch version info: %s", err) log.Errorf("failed to fetch version info: %s", err)
return false return false

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
) )
const httpAgent = "pkg/test"
func TestNewUpdate(t *testing.T) { func TestNewUpdate(t *testing.T) {
version = "1.0.0" version = "1.0.0"
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -21,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
wg.Add(1) wg.Add(1)
onUpdate := false onUpdate := false
u := NewUpdate() u := NewUpdate(httpAgent)
defer u.StopWatch() defer u.StopWatch()
u.SetOnUpdateListener(func() { u.SetOnUpdateListener(func() {
onUpdate = true onUpdate = true
@@ -46,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
wg.Add(1) wg.Add(1)
onUpdate := false onUpdate := false
u := NewUpdate() u := NewUpdate(httpAgent)
defer u.StopWatch() defer u.StopWatch()
u.SetOnUpdateListener(func() { u.SetOnUpdateListener(func() {
onUpdate = true onUpdate = true
@@ -71,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
wg.Add(1) wg.Add(1)
onUpdate := false onUpdate := false
u := NewUpdate() u := NewUpdate(httpAgent)
defer u.StopWatch() defer u.StopWatch()
u.SetOnUpdateListener(func() { u.SetOnUpdateListener(func() {
onUpdate = true onUpdate = true