Compare commits

...

83 Commits

Author SHA1 Message Date
Viktor Liu
1010629ea2 Merge pull request #2 from lixmal/windows-msi-cmd-fix
Remove auto sections
2025-06-24 14:28:17 +02:00
Viktor Liu
955d48abb9 Remove auto sections 2025-06-24 14:27:30 +02:00
Viktor Liu
d6dad20c83 Merge pull request #1 from lixmal/windows-msi-cmd-fix
Windows msi cmd fix
2025-06-24 14:19:56 +02:00
Viktor Liu
04676c5368 Remove auto sections 2025-06-24 14:17:15 +02:00
Viktor Liu
b53a517bf8 Don't open cmd.exe during MSI actions 2025-06-24 13:01:20 +02:00
Viktor Liu
f37aa2cc9d [misc] Specify netbird binary location in Dockerfiles (#4024) 2025-06-23 10:09:02 +02:00
Maycon Santos
5343bee7b2 [management] check and log on new management version (#4029)
This PR enhances the version checker to send a custom User-Agent header when polling for updates, and configures both the management CLI and client UI to use distinct agents. 

- NewUpdate now takes an `httpAgent` string to set the User-Agent header.
- `fetchVersion` builds a custom HTTP request (instead of `http.Get`) and sets the User-Agent.
- Management CLI and client UI now pass `"nb/management"` and `"nb/client-ui"` respectively to NewUpdate.
- Tests updated to supply an `httpAgent` constant.
- Logs if there is a new version available for management
2025-06-22 16:44:33 +02:00
Maycon Santos
870e29db63 [misc] add additional metrics (#4028)
* add additional metrics

we are collecting active rosenpass, ssh from the client side
we are also collecting active user peers and active users

* remove duplicated
2025-06-22 13:44:25 +02:00
Maycon Santos
08e9b05d51 [client] close windows when process needs to exit (#4027)
This PR fixes a bug by ensuring that the advanced settings and re-authentication windows are closed appropriately when the main GUI process exits.

- Updated runSelfCommand calls throughout the UI to pass a context parameter.
- Modified runSelfCommand’s signature and its internal command invocation to use exec.CommandContext for proper cancellation handling.
2025-06-22 10:33:04 +02:00
hakansa
3581648071 [client] Refactor showLoginURL to improve error handling and connection status checks (#4026)
This PR refactors showLoginURL to improve error handling and connection status checks by delaying the login fetch until user interaction and closing the pop-up if already connected.

- Moved s.login(false) call into the click handler to defer network I/O.
- Added a conn.Status check after opening the URL to skip reconnection if already connected.
- Enhanced error logs for missing verification URLs and service status failures.
2025-06-22 10:03:58 +02:00
Viktor Liu
2a51609436 [client] Handle lazy routing peers that are part of HA groups (#3943)
* Activate new lazy routing peers if the HA group is active
* Prevent lazy peers going to idle if HA group members are active (#3948)
2025-06-20 18:07:19 +02:00
Pascal Fischer
83457f8b99 [management] add transaction for integrated validator groups update and primary account update (#4014) 2025-06-20 12:13:24 +02:00
Pascal Fischer
b45284f086 [management] export ephemeral peer flag on api (#4004) 2025-06-19 16:46:56 +02:00
Bethuel Mmbaga
e9016aecea [management] Add backward compatibility for older clients without firewall rules port range support (#4003)
Adds backward compatibility for clients with versions prior to v0.48.0 that do not support port range firewall rules.

- Skips generation of firewall rules with multi-port ranges for older clients
- Preserves support for single-port ranges by treating them as individual port rules, ensuring compatibility with older clients
2025-06-19 13:07:06 +03:00
Viktor Liu
23b5d45b68 [client] Fix port range squashing (#4007) 2025-06-18 18:56:48 +02:00
Viktor Liu
0e5dc9d412 [client] Add more Android advanced settings (#4001) 2025-06-18 17:23:23 +02:00
Zoltan Papp
91f7ee6a3c Fix route notification
On Android ignore the dynamic roots in the route notifications
2025-06-18 16:49:03 +02:00
Bethuel Mmbaga
7c6b85b4cb [management] Refactor routes to use store methods (#2928) 2025-06-18 16:40:29 +03:00
hakansa
08c9107c61 [client] fix connection state handling (#3995)
[client] fix connection state handling (#3995)
2025-06-17 17:14:08 +03:00
hakansa
81d83245e1 [client] Fix logic in updateStatus to correctly handle connection state (#3994)
[client] Fix logic in updateStatus to correctly handle connection state (#3994)
2025-06-17 17:02:04 +03:00
Maycon Santos
af2b427751 [management] Avoid recalculating next peer expiration (#3991)
* Avoid recalculating next peer expiration

- Check if an account schedule is already running
- Cancel executing schedules only when changes occurs
- Add more context info to logs

* fix tests
2025-06-17 15:14:11 +02:00
hakansa
f61ebdb3bc [client] Fix DNS Interceptor Build Error (#3993)
[client] Fix DNS Interceptor Build Error
2025-06-17 16:07:14 +03:00
Viktor Liu
de7384e8ea [client] Tighten allowed domains for dns forwarder (#3978) 2025-06-17 14:03:00 +02:00
Viktor Liu
75c1be69cf [client] Prioritze the local resolver in the dns handler chain (#3965) 2025-06-17 14:02:30 +02:00
hakansa
424ae28de9 [client] Fix UI Download URL (#3990)
[client] Fix UI Download URL
2025-06-17 11:55:48 +03:00
Viktor Liu
d4a800edd5 [client] Fix status recorder panic (#3988) 2025-06-17 01:20:26 +02:00
Maycon Santos
dd9917f1a8 [misc] add missing images (#3987) 2025-06-16 21:05:49 +02:00
Viktor Liu
8df8c1012f [client] Support wildcard DNS on iOS (#3979) 2025-06-16 18:33:51 +02:00
Viktor Liu
bfa5c21d2d [client] Improve icmp conntrack log (#3963) 2025-06-16 10:12:59 +02:00
Maycon Santos
b1247a14ba [management] Use xID for setup key IDs to avoid id collisions (#3977)
This PR addresses potential ID collisions by switching the setup key ID generation from a hash-based approach to using xid-generated IDs.

Replace the hash function with xid.New().String()
Remove obsolete imports and the Hash() function
2025-06-14 12:24:16 +01:00
Philippe Vaucher
f595057a0b [signal] Set flags from environment variables (#3972) 2025-06-14 00:08:34 +02:00
hakansa
089d442fb2 [client] Display login popup on session expiration (#3955)
This PR implements a feature enhancement to display a login popup when the session expires. Key changes include updating flag handling and client construction to support a new login URL popup, revising login and notification handling logic to use the new popup, and updating status and server-side session state management accordingly.
2025-06-13 23:51:57 +02:00
Viktor Liu
04a3765391 [client] Fix unncessary UI updates (#3785) 2025-06-13 20:38:50 +02:00
Zoltan Papp
d24d8328f9 [client] Propagation networks for Android client (#3966)
Add networks propagation
2025-06-13 11:04:17 +02:00
Vlad
4f63996ae8 [management] added events streaming metrics (#3814) 2025-06-12 18:48:54 +01:00
Zoltan Papp
bdf2994e97 [client] Feature/android preferences (#3957)
Propagate Rosenpass preferences for Android
2025-06-12 09:41:12 +02:00
Bethuel Mmbaga
6d654acbad [management] Persist peer flags in meta updates (#3958)
This PR adds persistence for peer feature flags when updating metadata, including equality checks, gRPC extraction, and corresponding unit tests.

- Introduce a new `Flags` struct with `isEqual` and incorporate it into `PeerSystemMeta`.
- Update `UpdateMetaIfNew` logic to consider flag changes.
- Extend gRPC server’s `extractPeerMeta` to populate `Flags` and add tests for `Flags.isEqual`.
2025-06-11 22:39:59 +02:00
Viktor Liu
3e43298471 [client] Fix local resolver returning error for existing domains with other types (#3959) 2025-06-11 21:08:45 +02:00
Bethuel Mmbaga
0ad2590974 [misc] Push all docker images to ghcr in releases (#3954)
This PR refactors the release process to push all release images to the GitHub Container Registry.

Updated image naming in .goreleaser.yaml to include new registry references.
Added a GitHub Actions step in .github/workflows/release.yml to log in to the GitHub Container Registry.
2025-06-11 15:28:30 +02:00
Zoltan Papp
9d11257b1a [client] Carry the peer's actual state with the notification. (#3929)
- Removed separate thread execution of GetStates during notifications.
- Updated notification handler to rely on state data included in the notification payload.
2025-06-11 13:33:38 +02:00
Bethuel Mmbaga
4ee1635baa [management] Propagate user groups when group propagation setting is re-enabled (#3912) 2025-06-11 14:32:16 +03:00
Zoltan Papp
75feb0da8b [client] Refactor context management in ConnMgr for clarity and consistency (#3951)
In the conn_mgr we must distinguish two contexts. One is relevant for lazy-manager, and one (engine context) is relevant for peer creation. If we use the incorrect context, then when we disable the lazy connection feature, we cancel the peer connections too, instead of just the lazy manager.
2025-06-11 11:04:44 +02:00
Bethuel Mmbaga
87376afd13 [management] Enable unidirectional rules for all port policy (#3826) 2025-06-10 18:02:45 +03:00
Bethuel Mmbaga
b76d9e8e9e [management] Add support for port ranges in firewall rules (#3823) 2025-06-10 18:02:13 +03:00
Viktor Liu
e71383dcb9 [client] Add missing client meta flags (#3898) 2025-06-10 14:27:58 +02:00
Viktor Liu
e002a2e6e8 [client] Add more advanced settings to UI (#3941) 2025-06-10 14:27:06 +02:00
Viktor Liu
6127a01196 [client] Remove strings from allowed IPs (#3920) 2025-06-10 14:26:28 +02:00
Bethuel Mmbaga
de27d6df36 [management] Add account ID index to activity events (#3946) 2025-06-09 14:34:53 +03:00
Viktor Liu
3c535cdd2b [client] Add lazy connections to routed networks (#3908) 2025-06-08 14:10:34 +02:00
Maycon Santos
0f050e5fe1 [client] Optmize process check time (#3938)
This PR optimizes the process check time by updating the implementation of getRunningProcesses and introducing new benchmark tests.

Updated getRunningProcesses to use process.Pids() instead of process.Processes()
Added benchmark tests for both the new and the legacy implementations

Benchmark: https://github.com/netbirdio/netbird/actions/runs/15512741612

todo: evaluate windows optmizations and caching risks
2025-06-08 12:19:54 +02:00
Maycon Santos
0f7c7f1da2 [misc] use generic slack url (#3939) 2025-06-08 10:53:27 +02:00
Maycon Santos
b56f61bf1b [misc] fix relay exposed address test (#3931) 2025-06-05 15:44:44 +02:00
Viktor Liu
64f111923e [client] Increase stun status probe timeout (#3930) 2025-06-05 15:22:59 +02:00
Abdul Latif
122a89c02b [misc] remove error causing dnf config-manager add (#3925) 2025-06-05 14:28:19 +02:00
Robert Neumann
c6cceba381 Update getting-started-with-zitadel.sh - fix zitadel user console (#3446) 2025-06-05 14:16:04 +02:00
Ghazy Abdallah
6c0cdb6ed1 [misc] fix: traefik relay accessibility (#3696) 2025-06-05 14:15:01 +02:00
Viktor Liu
84354951d3 [client] Add systemd netbird logs to debug bundle (#3917) 2025-06-05 13:54:15 +02:00
Viktor Liu
55957a1960 [client] Run registerdns before flushing (#3926)
* Run registerdns before flushing

* Disable WINS, dynamic updates and registration
2025-06-05 12:40:23 +02:00
Viktor Liu
df82a45d99 [client] Improve dns match trace log (#3928) 2025-06-05 12:39:58 +02:00
Zoltan Papp
9424b88db2 [client] Add output similar to wg show to the debug package (#3922) 2025-06-05 11:51:39 +02:00
Viktor Liu
609654eee7 [client] Allow userspace local forwarding to internal interfaces if requested (#3884) 2025-06-04 18:12:48 +02:00
Bethuel Mmbaga
b604c66140 [management] Add postgres support for activity event store (#3890) 2025-06-04 17:38:49 +03:00
Viktor Liu
ea4d13e96d [client] Use platform-native routing APIs for freeBSD, macOS and Windows 2025-06-04 16:28:58 +02:00
Pedro Maia Costa
87148c503f [management] support account retrieval and creation by private domain (#3825)
* [management] sys initiator save user (#3911)

* [management] activity events with multiple external account users (#3914)
2025-06-04 11:21:31 +01:00
Viktor Liu
0cd36baf67 [client] Allow the netbird service to log to console (#3916) 2025-06-03 13:09:39 +02:00
Viktor Liu
06980e7fa0 [client] Apply routes right away instead of on peer connection (#3907) 2025-06-03 10:53:39 +02:00
Viktor Liu
1ce4ee0cef [client] Add block inbound flag to disallow inbound connections of any kind (#3897) 2025-06-03 10:53:27 +02:00
Viktor Liu
f367925496 [client] Log duplicate client ui pid (#3915) 2025-06-03 10:52:10 +02:00
hakansa
616b19c064 [client] Add "Deselect All" Menu Item to Exit Node Menu (#3877)
* [client] Enhance exit node menu functionality with deselect all option

* Hide exit nodes before removal in recreateExitNodeMenu

* recreateExitNodeMenu adding mutex locks

* Refetch exit nodes after deselecting all in exit node menu
2025-06-03 09:49:13 +02:00
Zoltan Papp
af27aaf9af [client] Refactor peer state change subscription mechanism (#3910)
* Refactor peer state change subscription mechanism

Because the code generated new channel for every single event, was easy to miss notification.
Use single channel.

* Fix lint

* Avoid potential deadlock

* Fix test

* Add context

* Fix test
2025-06-03 09:20:33 +02:00
Maycon Santos
35287f8241 [misc] Fail linter workflows on codespell failures (#3913)
* Fail linter workflows on codespell failures

* testing workflow

* remove test
2025-06-03 00:37:51 +02:00
Pedro Maia Costa
07b220d91b [management] REST client impersonation (#3879) 2025-06-02 22:11:28 +02:00
Viktor Liu
41cd4952f1 [client] Apply return traffic rules only if firewall is stateless (#3895) 2025-06-02 12:11:54 +02:00
Zoltan Papp
f16f0c7831 [client] Fix HA router switch (#3889)
* Fix HA router switch.

- Simplify the notification filter logic.
Always send notification if a state has been changed

- Remove IP changes check because we never modify

* Notify only the proper listeners

* Fix test

* Fix TestGetPeerStateChangeNotifierLogic test

* Before lazy connection, when the peer disconnected, the status switched to disconnected.
After implementing lazy connection, the peer state is connecting, so we did not decrease the reference counters on the routes.

* When switch to idle notify the route mgr
2025-06-01 16:08:27 +02:00
Zoltan Papp
aa07b3b87b Fix deadlock (#3904) 2025-05-30 23:38:02 +02:00
Bethuel Mmbaga
2bef214cc0 [management] Fix user groups propagation (#3902) 2025-05-30 18:12:30 +03:00
hakansa
cfb2d82352 [client] Refactor exclude list handling to use a map for permanent connections (#3901)
[client] Refactor exclude list handling to use a map for permanent connections (#3901)
2025-05-30 16:54:49 +03:00
Bethuel Mmbaga
684501fd35 [management] Prevent deletion of peers linked to network routers (#3881)
- Prevent deletion of peers linked to network routers
- Add API endpoint to list all network routers
2025-05-29 18:50:00 +03:00
Zoltan Papp
0492c1724a [client, android] Fix/notifier threading (#3807)
- Fix potential deadlocks
- When adding a listener, immediately notify with the last known IP and fqdn.
2025-05-27 17:12:04 +02:00
Zoltan Papp
6f436e57b5 [server-test] Install libs for i386 tests (#3887)
Install libs for i386 tests
2025-05-27 16:42:06 +02:00
Bethuel Mmbaga
a0d28f9851 [management] Reset test containers after cleanup (#3885) 2025-05-27 14:42:00 +03:00
Zoltan Papp
cdd27a9fe5 [client, android] Fix/android enable server route (#3806)
Enable the server route; otherwise, the manager throws an error and the engine will restart.
2025-05-27 13:32:54 +02:00
Bethuel Mmbaga
5523040acd [management] Add correlated network traffic event schema (#3680) 2025-05-27 13:47:53 +03:00
218 changed files with 11220 additions and 6674 deletions

View File

@@ -1,21 +0,0 @@
name: Git Town
on:
pull_request:
branches:
- '**'
jobs:
git-town:
name: Display the branch stack
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
with:
skip-single-stacks: true

View File

@@ -1,46 +0,0 @@
name: "Darwin"
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: macos-latest
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
macos-gotest-
macos-go-
- name: Install libpcap
run: brew install libpcap
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -1,52 +0,0 @@
name: "FreeBSD"
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Test in FreeBSD
id: test
uses: vmactions/freebsd-vm@v1
with:
usesh: true
copyback: false
release: "14.2"
prepare: |
pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands
# -e - to faile on first error
run: |
set -e -x
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...

View File

@@ -1,560 +0,0 @@
name: Linux
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
build-cache:
name: "Build Cache"
runs-on: ubuntu-22.04
outputs:
management: ${{ steps.filter.outputs.management }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- uses: dorny/paths-filter@v3
id: filter
with:
filters: |
management:
- 'management/**'
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
id: cache
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: steps.cache.outputs.cache-hit != 'true'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Build client
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 go build .
- name: Build client 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: client
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
- name: Build management
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 go build .
- name: Build management 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: management
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
- name: Build signal
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 go build .
- name: Build signal 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: signal
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
- name: Build relay
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 go build .
- name: Build relay 386
if: steps.cache.outputs.cache-hit != 'true'
working-directory: relay
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
test:
name: "Client / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
${{ steps.go-env.outputs.cache_dir }}
${{ steps.go-env.outputs.modcache_dir }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Run tests in container
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
docker run --rm \
--cap-add=NET_ADMIN \
--privileged \
-v $PWD:/app \
-w /app \
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
-e CGO_ENABLED=1 \
-e CI=true \
-e DOCKER_CI=true \
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay:
name: "Relay / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_signal:
name: "Signal / Unit"
needs: [build-cache]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
go test \
-exec 'sudo' \
-timeout 10m ./signal/...
test_management:
name: "Management / Unit"
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/...
benchmark:
name: "Management / Benchmark"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...
api_benchmark:
name: "Management / Benchmark (API)"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/...
api_integration_test:
name: "Management / Integration"
needs: [ build-cache ]
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
strategy:
fail-fast: false
matrix:
arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/...

View File

@@ -1,72 +0,0 @@
name: "Windows"
on:
push:
branches:
- main
pull_request:
env:
downloadPath: '${{ github.workspace }}\temp'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test:
name: "Client / Unit"
runs-on: windows-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
id: go
with:
go-version: "1.23.x"
cache: false
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-
${{ runner.os }}-go-
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompressing wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- run: mv ${{ env.downloadPath }}/wintun/bin/amd64/wintun.dll 'C:\Windows\System32\'
- run: choco install -y sysinternals --ignore-checksums
- run: choco install -y mingw
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -1,59 +0,0 @@
name: Lint
on: [pull_request]
permissions:
contents: read
pull-requests: read
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
codespell:
name: codespell
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum
only_warn: 1
golangci:
strategy:
fail-fast: false
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
include:
- os: macos-latest
display_name: Darwin
- os: windows-latest
display_name: Windows
- os: ubuntu-latest
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check for duplicate constants
if: matrix.os == 'ubuntu-latest'
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Install dependencies
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v4
with:
version: latest
args: --timeout=12m --out-format colored-line-number

View File

@@ -1,37 +0,0 @@
name: Test installation
on:
push:
branches:
- main
pull_request:
paths:
- "release_files/install.sh"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-install-script:
strategy:
fail-fast: false
max-parallel: 2
matrix:
os: [ubuntu-latest, macos-latest]
skip_ui_mode: [true, false]
install_binary: [true, false]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: run install script
env:
SKIP_UI_APP: ${{ matrix.skip_ui_mode }}
USE_BIN_INSTALL: ${{ matrix.install_binary }}
GITHUB_TOKEN: ${{ secrets.RO_API_CALLER_TOKEN }}
run: |
[ "$SKIP_UI_APP" == "false" ] && export XDG_CURRENT_DESKTOP="none"
cat release_files/install.sh | sh -x
- name: check cli binary
run: command -v netbird

View File

@@ -1,67 +0,0 @@
name: Mobile
on:
push:
branches:
- main
pull_request:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
android_build:
name: "Android / Build"
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Setup Android SDK
uses: android-actions/setup-android@v3
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@v4
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@v4
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
- name: Setup NDK
run: /usr/local/lib/android/sdk/cmdline-tools/7.0/bin/sdkmanager --install "ndk;23.1.7779620"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build android netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -o $GITHUB_WORKSPACE/netbird.aar -javapkg=io.netbird.gomobile -ldflags="-X golang.zx2c4.com/wireguard/ipc.socketDirectory=/data/data/io.netbird.client/cache/wireguard -X github.com/netbirdio/netbird/version.version=buildtest" $GITHUB_WORKSPACE/client/android
env:
CGO_ENABLED: 0
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
ios_build:
name: "iOS / Build"
runs-on: macos-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: install gomobile
run: go install golang.org/x/mobile/cmd/gomobile@v0.0.0-20240404231514-09dbf07665ed
- name: gomobile init
run: gomobile init
- name: build iOS netbird lib
run: PATH=$PATH:$(go env GOPATH) gomobile bind -target=ios -bundleid=io.netbird.framework -ldflags="-X github.com/netbirdio/netbird/version.version=buildtest" -o ./NetBirdSDK.xcframework ./client/ios/NetBirdSDK
env:
CGO_ENABLED: 0

View File

@@ -55,16 +55,23 @@ jobs:
run: go mod tidy run: go mod tidy
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU # - name: Set up QEMU
uses: docker/setup-qemu-action@v2 # uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx # - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 # uses: docker/setup-buildx-action@v2
- name: Login to Docker hub # - name: Login to Docker hub
if: github.event_name != 'pull_request' # if: github.event_name != 'pull_request'
uses: docker/login-action@v1 # uses: docker/login-action@v1
with: # with:
username: ${{ secrets.DOCKER_USER }} # username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} # password: ${{ secrets.DOCKER_TOKEN }}
# - name: Log in to the GitHub container registry
# if: github.event_name != 'pull_request'
# uses: docker/login-action@v3
# with:
# registry: ghcr.io
# username: ${{ github.actor }}
# password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu

View File

@@ -1,22 +0,0 @@
name: sync main
on:
push:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_main:
runs-on: ubuntu-latest
steps:
- name: Trigger main branch sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-main.yml
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "sha": "${{ github.sha }}" }'

View File

@@ -1,23 +0,0 @@
name: sync tag
on:
push:
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
steps:
- name: Trigger release tag sync
uses: benc-uk/workflow-dispatch@v1
with:
workflow: sync-tag.yml
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -1,310 +0,0 @@
name: Test Infrastructure files
on:
push:
branches:
- main
pull_request:
paths:
- 'infrastructure_files/**'
- '.github/workflows/test-infrastructure-files.yml'
- 'management/cmd/**'
- 'signal/cmd/**'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
jobs:
test-docker-compose:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
env:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: postgres
POSTGRES_DB: netbird
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
ports:
- 5432:5432
mysql:
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
env:
MYSQL_USER: netbird
MYSQL_PASSWORD: mysql
MYSQL_ROOT_PASSWORD: mysqlroot
MYSQL_DATABASE: netbird
options: >-
--health-cmd "mysqladmin ping --silent"
--health-interval 10s
--health-timeout 5s
ports:
- 3306:3306
steps:
- name: Set Database Connection String
run: |
if [ "${{ matrix.store }}" == "postgres" ]; then
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN=host=$(hostname -I | awk '{print $1}') user=netbird password=postgres dbname=netbird port=5432" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
if [ "${{ matrix.store }}" == "mysql" ]; then
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
fi
- name: Install jq
run: sudo apt-get install -y jq
- name: Install curl
run: sudo apt-get install -y curl
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Checkout code
uses: actions/checkout@v4
- name: cp setup.env
run: cp infrastructure_files/tests/setup.env infrastructure_files/
- name: run configure
working-directory: infrastructure_files
run: bash -x configure.sh
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
CI_NETBIRD_USE_AUTH0: true
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
working-directory: infrastructure_files/artifacts
env:
CI_NETBIRD_DOMAIN: localhost
CI_NETBIRD_AUTH_CLIENT_ID: testing.client.id
CI_NETBIRD_AUTH_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_AUTH_AUDIENCE: testing.ci
CI_NETBIRD_AUTH_OIDC_CONFIGURATION_ENDPOINT: https://example.eu.auth0.com/.well-known/openid-configuration
CI_NETBIRD_USE_AUTH0: true
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_AUTH_AUTHORITY: https://example.eu.auth0.com/
CI_NETBIRD_AUTH_JWT_CERTS: https://example.eu.auth0.com/.well-known/jwks.json
CI_NETBIRD_AUTH_TOKEN_ENDPOINT: https://example.eu.auth0.com/oauth/token
CI_NETBIRD_AUTH_DEVICE_AUTH_ENDPOINT: https://example.eu.auth0.com/oauth/device/code
CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT: https://example.eu.auth0.com/authorize
CI_NETBIRD_AUTH_REDIRECT_URI: "/peers"
CI_NETBIRD_TOKEN_SOURCE: "idToken"
CI_NETBIRD_AUTH_USER_ID_CLAIM: "email"
CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE: "super"
CI_NETBIRD_AUTH_DEVICE_AUTH_SCOPE: "openid email"
CI_NETBIRD_MGMT_IDP: "none"
CI_NETBIRD_IDP_MGMT_CLIENT_ID: testing.client.id
CI_NETBIRD_IDP_MGMT_CLIENT_SECRET: testing.client.secret
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
run: |
set -x
grep AUTH_CLIENT_ID docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep AUTH_CLIENT_SECRET docker-compose.yml | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep AUTH_AUTHORITY docker-compose.yml | grep $CI_NETBIRD_AUTH_AUTHORITY
grep AUTH_AUDIENCE docker-compose.yml | grep $CI_NETBIRD_AUTH_AUDIENCE
grep AUTH_SUPPORTED_SCOPES docker-compose.yml | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep USE_AUTH0 docker-compose.yml | grep $CI_NETBIRD_USE_AUTH0
grep NETBIRD_MGMT_API_ENDPOINT docker-compose.yml | grep "$CI_NETBIRD_DOMAIN:33073"
grep AUTH_REDIRECT_URI docker-compose.yml | grep $CI_NETBIRD_AUTH_REDIRECT_URI
grep AUTH_SILENT_REDIRECT_URI docker-compose.yml | egrep 'AUTH_SILENT_REDIRECT_URI=$'
grep $CI_NETBIRD_SIGNAL_PORT docker-compose.yml | grep ':80'
grep LETSENCRYPT_DOMAIN docker-compose.yml | egrep 'LETSENCRYPT_DOMAIN=$'
grep NETBIRD_TOKEN_SOURCE docker-compose.yml | grep $CI_NETBIRD_TOKEN_SOURCE
grep AuthUserIDClaim management.json | grep $CI_NETBIRD_AUTH_USER_ID_CLAIM
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep -A 3 DeviceAuthorizationFlow management.json | grep -A 1 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_DEVICE_AUTH_AUDIENCE
grep Engine management.json | grep "$CI_NETBIRD_STORE_CONFIG_ENGINE"
grep IdpSignKeyRefreshEnabled management.json | grep "$CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH"
grep UseIDToken management.json | grep false
grep -A 1 IdpManagerConfig management.json | grep ManagerType | grep $CI_NETBIRD_MGMT_IDP
grep -A 3 IdpManagerConfig management.json | grep -A 1 ClientConfig | grep Issuer | grep $CI_NETBIRD_AUTH_AUTHORITY
grep -A 4 IdpManagerConfig management.json | grep -A 2 ClientConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 5 IdpManagerConfig management.json | grep -A 3 ClientConfig | grep ClientID | grep $CI_NETBIRD_IDP_MGMT_CLIENT_ID
grep -A 6 IdpManagerConfig management.json | grep -A 4 ClientConfig | grep ClientSecret | grep $CI_NETBIRD_IDP_MGMT_CLIENT_SECRET
grep -A 7 IdpManagerConfig management.json | grep -A 5 ClientConfig | grep GrantType | grep client_credentials
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Audience | grep $CI_NETBIRD_AUTH_AUDIENCE
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientID | grep $CI_NETBIRD_AUTH_CLIENT_ID
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep ClientSecret | grep $CI_NETBIRD_AUTH_CLIENT_SECRET
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep AuthorizationEndpoint | grep $CI_NETBIRD_AUTH_PKCE_AUTHORIZATION_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep TokenEndpoint | grep $CI_NETBIRD_AUTH_TOKEN_ENDPOINT
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Build management binary
working-directory: management
run: CGO_ENABLED=1 go build -o netbird-mgmt main.go
- name: Build management docker image
working-directory: management
run: |
docker build -t netbirdio/management:latest .
- name: Build signal binary
working-directory: signal
run: CGO_ENABLED=0 go build -o netbird-signal main.go
- name: Build signal docker image
working-directory: signal
run: |
docker build -t netbirdio/signal:latest .
- name: Build relay binary
working-directory: relay
run: CGO_ENABLED=0 go build -o netbird-relay main.go
- name: Build relay docker image
working-directory: relay
run: |
docker build -t netbirdio/relay:latest .
- name: run docker compose up
working-directory: infrastructure_files/artifacts
run: |
docker compose up -d
sleep 5
docker compose ps
docker compose logs --tail=20
- name: test running containers
run: |
count=$(docker compose ps --format json | jq '. | select(.Name | contains("artifacts")) | .State' | grep -c running)
test $count -eq 5 || docker compose logs
working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City_[0-9]*.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames_[0-9]*.db
test-getting-started-script:
runs-on: ubuntu-latest
steps:
- name: Install jq
run: sudo apt-get install -y jq
- name: Checkout code
uses: actions/checkout@v4
- name: run script with Zitadel PostgreSQL
run: NETBIRD_DOMAIN=use-ip bash -x infrastructure_files/getting-started-with-zitadel.sh
- name: test Caddy file gen postgres
run: test -f Caddyfile
- name: test docker-compose file gen postgres
run: test -f docker-compose.yml
- name: test management.json file gen postgres
run: test -f management.json
- name: test turnserver.conf file gen postgres
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen postgres
run: test -f zitadel.env
- name: test dashboard.env file gen postgres
run: test -f dashboard.env
- name: test relay.env file gen postgres
run: test -f relay.env
- name: test zdb.env file gen postgres
run: test -f zdb.env
- name: Postgres run cleanup
run: |
docker compose down --volumes --rmi all
rm -rf docker-compose.yml Caddyfile zitadel.env dashboard.env machinekey/zitadel-admin-sa.token turnserver.conf management.json zdb.env
- name: run script with Zitadel CockroachDB
run: bash -x infrastructure_files/getting-started-with-zitadel.sh
env:
NETBIRD_DOMAIN: use-ip
ZITADEL_DATABASE: cockroach
- name: test Caddy file gen CockroachDB
run: test -f Caddyfile
- name: test docker-compose file gen CockroachDB
run: test -f docker-compose.yml
- name: test management.json file gen CockroachDB
run: test -f management.json
- name: test turnserver.conf file gen CockroachDB
run: |
set -x
test -f turnserver.conf
grep external-ip turnserver.conf
- name: test zitadel.env file gen CockroachDB
run: test -f zitadel.env
- name: test dashboard.env file gen CockroachDB
run: test -f dashboard.env
- name: test relay.env file gen CockroachDB
run: test -f relay.env

View File

@@ -1,22 +0,0 @@
name: update docs
on:
push:
tags:
- 'v*'
paths:
- 'management/server/http/api/openapi.yml'
jobs:
trigger_docs_api_update:
runs-on: ubuntu-latest
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger API pages generation
uses: benc-uk/workflow-dispatch@v1
with:
workflow: generate api pages
repo: netbirdio/docs
ref: "refs/heads/main"
token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }'

View File

@@ -146,439 +146,541 @@ nfpms:
scripts: scripts:
postinstall: "release_files/post_install.sh" postinstall: "release_files/post_install.sh"
preremove: "release_files/pre_remove.sh" preremove: "release_files/pre_remove.sh"
dockers: # dockers:
- image_templates: # - image_templates:
- netbirdio/netbird:{{ .Version }}-amd64 # - netbirdio/netbird:{{ .Version }}-amd64
ids: # - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- netbird # ids:
goarch: amd64 # - netbird
use: buildx # goarch: amd64
dockerfile: client/Dockerfile # use: buildx
build_flag_templates: # dockerfile: client/Dockerfile
- "--platform=linux/amd64" # build_flag_templates:
- "--label=org.opencontainers.image.created={{.Date}}" # - "--platform=linux/amd64"
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.version={{.Version}}"
- image_templates: # - "--label=maintainer=dev@netbird.io"
- netbirdio/netbird:{{ .Version }}-arm64v8 # - image_templates:
ids: # - netbirdio/netbird:{{ .Version }}-arm64v8
- netbird # - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
goarch: arm64 # ids:
use: buildx # - netbird
dockerfile: client/Dockerfile # goarch: arm64
build_flag_templates: # use: buildx
- "--platform=linux/arm64" # dockerfile: client/Dockerfile
- "--label=org.opencontainers.image.created={{.Date}}" # build_flag_templates:
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- image_templates: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbirdio/netbird:{{ .Version }}-arm # - "--label=maintainer=dev@netbird.io"
ids: # - image_templates:
- netbird # - netbirdio/netbird:{{ .Version }}-arm
goarch: arm # - ghcr.io/netbirdio/netbird:{{ .Version }}-arm
goarm: 6 # ids:
use: buildx # - netbird
dockerfile: client/Dockerfile # goarch: arm
build_flag_templates: # goarm: 6
- "--platform=linux/arm" # use: buildx
- "--label=org.opencontainers.image.created={{.Date}}" # dockerfile: client/Dockerfile
- "--label=org.opencontainers.image.title={{.ProjectName}}" # build_flag_templates:
- "--label=org.opencontainers.image.version={{.Version}}" # - "--platform=linux/arm"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
- image_templates: # - image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64 # - netbirdio/netbird:{{ .Version }}-rootless-amd64
ids: # - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- netbird # ids:
goarch: amd64 # - netbird
use: buildx # goarch: amd64
dockerfile: client/Dockerfile-rootless # use: buildx
build_flag_templates: # dockerfile: client/Dockerfile-rootless
- "--platform=linux/amd64" # build_flag_templates:
- "--label=org.opencontainers.image.created={{.Date}}" # - "--platform=linux/amd64"
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- image_templates: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8 # - "--label=maintainer=dev@netbird.io"
ids: # - image_templates:
- netbird # - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
goarch: arm64 # - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
use: buildx # ids:
dockerfile: client/Dockerfile-rootless # - netbird
build_flag_templates: # goarch: arm64
- "--platform=linux/arm64" # use: buildx
- "--label=org.opencontainers.image.created={{.Date}}" # dockerfile: client/Dockerfile-rootless
- "--label=org.opencontainers.image.title={{.ProjectName}}" # build_flag_templates:
- "--label=org.opencontainers.image.version={{.Version}}" # - "--platform=linux/arm64"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- image_templates: # - "--label=org.opencontainers.image.version={{.Version}}"
- netbirdio/netbird:{{ .Version }}-rootless-arm # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
ids: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbird # - "--label=maintainer=dev@netbird.io"
goarch: arm # - image_templates:
goarm: 6 # - netbirdio/netbird:{{ .Version }}-rootless-arm
use: buildx # - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
dockerfile: client/Dockerfile-rootless # ids:
build_flag_templates: # - netbird
- "--platform=linux/arm" # goarch: arm
- "--label=org.opencontainers.image.created={{.Date}}" # goarm: 6
- "--label=org.opencontainers.image.title={{.ProjectName}}" # use: buildx
- "--label=org.opencontainers.image.version={{.Version}}" # dockerfile: client/Dockerfile-rootless
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # build_flag_templates:
- "--label=maintainer=dev@netbird.io" # - "--platform=linux/arm"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
- image_templates: # - image_templates:
- netbirdio/relay:{{ .Version }}-amd64 # - netbirdio/relay:{{ .Version }}-amd64
ids: # - ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- netbird-relay # ids:
goarch: amd64 # - netbird-relay
use: buildx # goarch: amd64
dockerfile: relay/Dockerfile # use: buildx
build_flag_templates: # dockerfile: relay/Dockerfile
- "--platform=linux/amd64" # build_flag_templates:
- "--label=org.opencontainers.image.created={{.Date}}" # - "--platform=linux/amd64"
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- image_templates: # - "--label=maintainer=dev@netbird.io"
- netbirdio/relay:{{ .Version }}-arm64v8 # - image_templates:
ids: # - netbirdio/relay:{{ .Version }}-arm64v8
- netbird-relay # - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
goarch: arm64 # ids:
use: buildx # - netbird-relay
dockerfile: relay/Dockerfile # goarch: arm64
build_flag_templates: # use: buildx
- "--platform=linux/arm64" # dockerfile: relay/Dockerfile
- "--label=org.opencontainers.image.created={{.Date}}" # build_flag_templates:
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--platform=linux/arm64"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- image_templates: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbirdio/relay:{{ .Version }}-arm # - "--label=maintainer=dev@netbird.io"
ids: # - image_templates:
- netbird-relay # - netbirdio/relay:{{ .Version }}-arm
goarch: arm # - ghcr.io/netbirdio/relay:{{ .Version }}-arm
goarm: 6 # ids:
use: buildx # - netbird-relay
dockerfile: relay/Dockerfile # goarch: arm
build_flag_templates: # goarm: 6
- "--platform=linux/arm" # use: buildx
- "--label=org.opencontainers.image.created={{.Date}}" # dockerfile: relay/Dockerfile
- "--label=org.opencontainers.image.title={{.ProjectName}}" # build_flag_templates:
- "--label=org.opencontainers.image.version={{.Version}}" # - "--platform=linux/arm"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.version={{.Version}}"
- image_templates: # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- netbirdio/signal:{{ .Version }}-amd64 # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
ids: # - "--label=maintainer=dev@netbird.io"
- netbird-signal # - image_templates:
goarch: amd64 # - netbirdio/signal:{{ .Version }}-amd64
use: buildx # - ghcr.io/netbirdio/signal:{{ .Version }}-amd64
dockerfile: signal/Dockerfile # ids:
build_flag_templates: # - netbird-signal
- "--platform=linux/amd64" # goarch: amd64
- "--label=org.opencontainers.image.created={{.Date}}" # use: buildx
- "--label=org.opencontainers.image.title={{.ProjectName}}" # dockerfile: signal/Dockerfile
- "--label=org.opencontainers.image.version={{.Version}}" # build_flag_templates:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- image_templates: # - "--label=org.opencontainers.image.version={{.Version}}"
- netbirdio/signal:{{ .Version }}-arm64v8 # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
ids: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbird-signal # - "--label=maintainer=dev@netbird.io"
goarch: arm64 # - image_templates:
use: buildx # - netbirdio/signal:{{ .Version }}-arm64v8
dockerfile: signal/Dockerfile # - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
build_flag_templates: # ids:
- "--platform=linux/arm64" # - netbird-signal
- "--label=org.opencontainers.image.created={{.Date}}" # goarch: arm64
- "--label=org.opencontainers.image.title={{.ProjectName}}" # use: buildx
- "--label=org.opencontainers.image.version={{.Version}}" # dockerfile: signal/Dockerfile
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # build_flag_templates:
- "--label=org.opencontainers.image.version={{.Version}}" # - "--platform=linux/arm64"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.created={{.Date}}"
- image_templates: # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- netbirdio/signal:{{ .Version }}-arm # - "--label=org.opencontainers.image.version={{.Version}}"
ids: # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- netbird-signal # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
goarch: arm # - "--label=maintainer=dev@netbird.io"
goarm: 6 # - image_templates:
use: buildx # - netbirdio/signal:{{ .Version }}-arm
dockerfile: signal/Dockerfile # - ghcr.io/netbirdio/signal:{{ .Version }}-arm
build_flag_templates: # ids:
- "--platform=linux/arm" # - netbird-signal
- "--label=org.opencontainers.image.created={{.Date}}" # goarch: arm
- "--label=org.opencontainers.image.title={{.ProjectName}}" # goarm: 6
- "--label=org.opencontainers.image.version={{.Version}}" # use: buildx
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # dockerfile: signal/Dockerfile
- "--label=org.opencontainers.image.version={{.Version}}" # build_flag_templates:
- "--label=maintainer=dev@netbird.io" # - "--platform=linux/arm"
- image_templates: # - "--label=org.opencontainers.image.created={{.Date}}"
- netbirdio/management:{{ .Version }}-amd64 # - "--label=org.opencontainers.image.title={{.ProjectName}}"
ids: # - "--label=org.opencontainers.image.version={{.Version}}"
- netbird-mgmt # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
goarch: amd64 # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
use: buildx # - "--label=maintainer=dev@netbird.io"
dockerfile: management/Dockerfile # - image_templates:
build_flag_templates: # - netbirdio/management:{{ .Version }}-amd64
- "--platform=linux/amd64" # - ghcr.io/netbirdio/management:{{ .Version }}-amd64
- "--label=org.opencontainers.image.created={{.Date}}" # ids:
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - netbird-mgmt
- "--label=org.opencontainers.image.version={{.Version}}" # goarch: amd64
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # use: buildx
- "--label=org.opencontainers.image.version={{.Version}}" # dockerfile: management/Dockerfile
- "--label=maintainer=dev@netbird.io" # build_flag_templates:
- image_templates: # - "--platform=linux/amd64"
- netbirdio/management:{{ .Version }}-arm64v8 # - "--label=org.opencontainers.image.created={{.Date}}"
ids: # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- netbird-mgmt # - "--label=org.opencontainers.image.version={{.Version}}"
goarch: arm64 # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
use: buildx # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
dockerfile: management/Dockerfile # - "--label=maintainer=dev@netbird.io"
build_flag_templates: # - image_templates:
- "--platform=linux/arm64" # - netbirdio/management:{{ .Version }}-arm64v8
- "--label=org.opencontainers.image.created={{.Date}}" # - ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- "--label=org.opencontainers.image.title={{.ProjectName}}" # ids:
- "--label=org.opencontainers.image.version={{.Version}}" # - netbird-mgmt
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # goarch: arm64
- "--label=org.opencontainers.image.version={{.Version}}" # use: buildx
- "--label=maintainer=dev@netbird.io" # dockerfile: management/Dockerfile
- image_templates: # build_flag_templates:
- netbirdio/management:{{ .Version }}-arm # - "--platform=linux/arm64"
ids: # - "--label=org.opencontainers.image.created={{.Date}}"
- netbird-mgmt # - "--label=org.opencontainers.image.title={{.ProjectName}}"
goarch: arm # - "--label=org.opencontainers.image.version={{.Version}}"
goarm: 6 # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
use: buildx # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
dockerfile: management/Dockerfile # - "--label=maintainer=dev@netbird.io"
build_flag_templates: # - image_templates:
- "--platform=linux/arm" # - netbirdio/management:{{ .Version }}-arm
- "--label=org.opencontainers.image.created={{.Date}}" # - ghcr.io/netbirdio/management:{{ .Version }}-arm
- "--label=org.opencontainers.image.title={{.ProjectName}}" # ids:
- "--label=org.opencontainers.image.version={{.Version}}" # - netbird-mgmt
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # goarch: arm
- "--label=org.opencontainers.image.version={{.Version}}" # goarm: 6
- "--label=maintainer=dev@netbird.io" # use: buildx
- image_templates: # dockerfile: management/Dockerfile
- netbirdio/management:{{ .Version }}-debug-amd64 # build_flag_templates:
ids: # - "--platform=linux/arm"
- netbird-mgmt # - "--label=org.opencontainers.image.created={{.Date}}"
goarch: amd64 # - "--label=org.opencontainers.image.title={{.ProjectName}}"
use: buildx # - "--label=org.opencontainers.image.version={{.Version}}"
dockerfile: management/Dockerfile.debug # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
build_flag_templates: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--platform=linux/amd64" # - "--label=maintainer=dev@netbird.io"
- "--label=org.opencontainers.image.created={{.Date}}" # - image_templates:
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - netbirdio/management:{{ .Version }}-debug-amd64
- "--label=org.opencontainers.image.version={{.Version}}" # - ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # ids:
- "--label=org.opencontainers.image.version={{.Version}}" # - netbird-mgmt
- "--label=maintainer=dev@netbird.io" # goarch: amd64
- image_templates: # use: buildx
- netbirdio/management:{{ .Version }}-debug-arm64v8 # dockerfile: management/Dockerfile.debug
ids: # build_flag_templates:
- netbird-mgmt # - "--platform=linux/amd64"
goarch: arm64 # - "--label=org.opencontainers.image.created={{.Date}}"
use: buildx # - "--label=org.opencontainers.image.title={{.ProjectName}}"
dockerfile: management/Dockerfile.debug # - "--label=org.opencontainers.image.version={{.Version}}"
build_flag_templates: # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--platform=linux/arm64" # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=org.opencontainers.image.created={{.Date}}" # - "--label=maintainer=dev@netbird.io"
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - image_templates:
- "--label=org.opencontainers.image.version={{.Version}}" # - netbirdio/management:{{ .Version }}-debug-arm64v8
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- "--label=org.opencontainers.image.version={{.Version}}" # ids:
- "--label=maintainer=dev@netbird.io" # - netbird-mgmt
# goarch: arm64
# use: buildx
# dockerfile: management/Dockerfile.debug
# build_flag_templates:
# - "--platform=linux/arm64"
# - "--label=org.opencontainers.image.created={{.Date}}"
# - "--label=org.opencontainers.image.title={{.ProjectName}}"
# - "--label=org.opencontainers.image.version={{.Version}}"
# - "--label=org.opencontainers.image.revision={{.FullCommit}}"
# - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
# - "--label=maintainer=dev@netbird.io"
- image_templates: # - image_templates:
- netbirdio/management:{{ .Version }}-debug-arm # - netbirdio/management:{{ .Version }}-debug-arm
ids: # - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- netbird-mgmt # ids:
goarch: arm # - netbird-mgmt
goarm: 6 # goarch: arm
use: buildx # goarm: 6
dockerfile: management/Dockerfile.debug # use: buildx
build_flag_templates: # dockerfile: management/Dockerfile.debug
- "--platform=linux/arm" # build_flag_templates:
- "--label=org.opencontainers.image.created={{.Date}}" # - "--platform=linux/arm"
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- image_templates: # - "--label=maintainer=dev@netbird.io"
- netbirdio/upload:{{ .Version }}-amd64 # - image_templates:
ids: # - netbirdio/upload:{{ .Version }}-amd64
- netbird-upload # - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
goarch: amd64 # ids:
use: buildx # - netbird-upload
dockerfile: upload-server/Dockerfile # goarch: amd64
build_flag_templates: # use: buildx
- "--platform=linux/amd64" # dockerfile: upload-server/Dockerfile
- "--label=org.opencontainers.image.created={{.Date}}" # build_flag_templates:
- "--label=org.opencontainers.image.title={{.ProjectName}}" # - "--platform=linux/amd64"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- image_templates: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbirdio/upload:{{ .Version }}-arm64v8 # - "--label=maintainer=dev@netbird.io"
ids: # - image_templates:
- netbird-upload # - netbirdio/upload:{{ .Version }}-arm64v8
goarch: arm64 # - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
use: buildx # ids:
dockerfile: upload-server/Dockerfile # - netbird-upload
build_flag_templates: # goarch: arm64
- "--platform=linux/arm64" # use: buildx
- "--label=org.opencontainers.image.created={{.Date}}" # dockerfile: upload-server/Dockerfile
- "--label=org.opencontainers.image.title={{.ProjectName}}" # build_flag_templates:
- "--label=org.opencontainers.image.version={{.Version}}" # - "--platform=linux/arm64"
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.version={{.Version}}"
- image_templates: # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- netbirdio/upload:{{ .Version }}-arm # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
ids: # - "--label=maintainer=dev@netbird.io"
- netbird-upload # - image_templates:
goarch: arm # - netbirdio/upload:{{ .Version }}-arm
goarm: 6 # - ghcr.io/netbirdio/upload:{{ .Version }}-arm
use: buildx # ids:
dockerfile: upload-server/Dockerfile # - netbird-upload
build_flag_templates: # goarch: arm
- "--platform=linux/arm" # goarm: 6
- "--label=org.opencontainers.image.created={{.Date}}" # use: buildx
- "--label=org.opencontainers.image.title={{.ProjectName}}" # dockerfile: upload-server/Dockerfile
- "--label=org.opencontainers.image.version={{.Version}}" # build_flag_templates:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" # - "--platform=linux/arm"
- "--label=org.opencontainers.image.version={{.Version}}" # - "--label=org.opencontainers.image.created={{.Date}}"
- "--label=maintainer=dev@netbird.io" # - "--label=org.opencontainers.image.title={{.ProjectName}}"
docker_manifests: # - "--label=org.opencontainers.image.version={{.Version}}"
- name_template: netbirdio/netbird:{{ .Version }} # - "--label=org.opencontainers.image.revision={{.FullCommit}}"
image_templates: # - "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- netbirdio/netbird:{{ .Version }}-arm64v8 # - "--label=maintainer=dev@netbird.io"
- netbirdio/netbird:{{ .Version }}-arm # docker_manifests:
- netbirdio/netbird:{{ .Version }}-amd64 # - name_template: netbirdio/netbird:{{ .Version }}
# image_templates:
# - netbirdio/netbird:{{ .Version }}-arm64v8
# - netbirdio/netbird:{{ .Version }}-arm
# - netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: netbirdio/netbird:latest
# image_templates:
# - netbirdio/netbird:{{ .Version }}-arm64v8
# - netbirdio/netbird:{{ .Version }}-arm
# - netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: netbirdio/netbird:{{ .Version }}-rootless
# image_templates:
# - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - netbirdio/netbird:{{ .Version }}-rootless-arm
# - netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: netbirdio/netbird:rootless-latest
# image_templates:
# - netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - netbirdio/netbird:{{ .Version }}-rootless-arm
# - netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: netbirdio/relay:{{ .Version }}
# image_templates:
# - netbirdio/relay:{{ .Version }}-arm64v8
# - netbirdio/relay:{{ .Version }}-arm
# - netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: netbirdio/relay:latest
# image_templates:
# - netbirdio/relay:{{ .Version }}-arm64v8
# - netbirdio/relay:{{ .Version }}-arm
# - netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: netbirdio/signal:{{ .Version }}
# image_templates:
# - netbirdio/signal:{{ .Version }}-arm64v8
# - netbirdio/signal:{{ .Version }}-arm
# - netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: netbirdio/signal:latest
# image_templates:
# - netbirdio/signal:{{ .Version }}-arm64v8
# - netbirdio/signal:{{ .Version }}-arm
# - netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: netbirdio/management:{{ .Version }}
# image_templates:
# - netbirdio/management:{{ .Version }}-arm64v8
# - netbirdio/management:{{ .Version }}-arm
# - netbirdio/management:{{ .Version }}-amd64
#
# - name_template: netbirdio/management:latest
# image_templates:
# - netbirdio/management:{{ .Version }}-arm64v8
# - netbirdio/management:{{ .Version }}-arm
# - netbirdio/management:{{ .Version }}-amd64
#
# - name_template: netbirdio/management:debug-latest
# image_templates:
# - netbirdio/management:{{ .Version }}-debug-arm64v8
# - netbirdio/management:{{ .Version }}-debug-arm
# - netbirdio/management:{{ .Version }}-debug-amd64
# - name_template: netbirdio/upload:{{ .Version }}
# image_templates:
# - netbirdio/upload:{{ .Version }}-arm64v8
# - netbirdio/upload:{{ .Version }}-arm
# - netbirdio/upload:{{ .Version }}-amd64
#
# - name_template: netbirdio/upload:latest
# image_templates:
# - netbirdio/upload:{{ .Version }}-arm64v8
# - netbirdio/upload:{{ .Version }}-arm
# - netbirdio/upload:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:latest
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: ghcr.io/netbirdio/netbird:rootless-latest
# image_templates:
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
# - ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
#
# - name_template: ghcr.io/netbirdio/relay:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm
# - ghcr.io/netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/relay:latest
# image_templates:
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/relay:{{ .Version }}-arm
# - ghcr.io/netbirdio/relay:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/signal:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm
# - ghcr.io/netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/signal:latest
# image_templates:
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/signal:{{ .Version }}-arm
# - ghcr.io/netbirdio/signal:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/management:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/management:{{ .Version }}-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/management:latest
# image_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/management:{{ .Version }}-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/management:debug-latest
# image_templates:
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
# - ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
#
# - name_template: ghcr.io/netbirdio/upload:{{ .Version }}
# image_templates:
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm
# - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
#
# - name_template: ghcr.io/netbirdio/upload:latest
# image_templates:
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
# - ghcr.io/netbirdio/upload:{{ .Version }}-arm
# - ghcr.io/netbirdio/upload:{{ .Version }}-amd64
# brews:
# - ids:
# - default
# repository:
# owner: netbirdio
# name: homebrew-tap
# token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
# commit_author:
# name: Netbird
# email: dev@netbird.io
# description: Netbird project.
# download_strategy: CurlDownloadStrategy
# homepage: https://netbird.io/
# license: "BSD3"
# test: |
# system "#{bin}/{{ .ProjectName }} version"
- name_template: netbirdio/netbird:latest # uploads:
image_templates: # - name: debian
- netbirdio/netbird:{{ .Version }}-arm64v8 # ids:
- netbirdio/netbird:{{ .Version }}-arm # - netbird-deb
- netbirdio/netbird:{{ .Version }}-amd64 # mode: archive
# target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
# username: dev@wiretrustee.com
# method: PUT
- name_template: netbirdio/netbird:{{ .Version }}-rootless # - name: yum
image_templates: # ids:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8 # - netbird-rpm
- netbirdio/netbird:{{ .Version }}-rootless-arm # mode: archive
- netbirdio/netbird:{{ .Version }}-rootless-amd64 # target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
# username: dev@wiretrustee.com
- name_template: netbirdio/netbird:rootless-latest # method: PUT
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/signal:latest
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/management:{{ .Version }}
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:latest
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:debug-latest
image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
brews:
- ids:
- default
repository:
owner: netbirdio
name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
commit_author:
name: Netbird
email: dev@netbird.io
description: Netbird project.
download_strategy: CurlDownloadStrategy
homepage: https://netbird.io/
license: "BSD3"
test: |
system "#{bin}/{{ .ProjectName }} version"
uploads:
- name: debian
ids:
- netbird-deb
mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com
method: PUT
- name: yum
ids:
- netbird-rpm
mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com
method: PUT
checksum: checksum:
extra_files: extra_files:

View File

@@ -79,19 +79,19 @@ nfpms:
dependencies: dependencies:
- netbird - netbird
uploads: # uploads:
- name: debian # - name: debian
ids: # ids:
- netbird-ui-deb # - netbird-ui-deb
mode: archive # mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= # target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com # username: dev@wiretrustee.com
method: PUT # method: PUT
- name: yum # - name: yum
ids: # ids:
- netbird-ui-rpm # - netbird-ui-rpm
mode: archive # mode: archive
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }} # target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
username: dev@wiretrustee.com # username: dev@wiretrustee.com
method: PUT # method: PUT

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g"> <a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,7 +29,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a> Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
<br/> <br/>
</strong> </strong>

View File

@@ -1,6 +1,9 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,6 +1,7 @@
FROM alpine:3.21.0 FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \ RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird && adduser -D -h /var/lib/netbird netbird

View File

@@ -59,6 +59,8 @@ type Client struct {
deviceName string deviceName string
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
connectClient *internal.ConnectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -106,8 +108,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -132,8 +134,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener) return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -174,6 +176,53 @@ func (c *Client) PeersList() *PeerInfoArray {
return &PeerInfoArray{items: peerInfos} return &PeerInfoArray{items: peerInfos}
} }
func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
}
routeManager := engine.GetRouteManager()
if routeManager == nil {
log.Error("could not get route manager")
return nil
}
networkArray := &NetworkArray{
items: make([]Network, 0),
}
for id, routes := range routeManager.GetClientRoutesWithNetID() {
if len(routes) == 0 {
continue
}
if routes[0].IsDynamic() {
continue
}
peer, err := c.recorder.GetPeer(routes[0].Peer)
if err != nil {
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
continue
}
network := Network{
Name: string(id),
Network: routes[0].Network.String(),
Peer: peer.FQDN,
Status: peer.ConnStatus.String(),
}
networkArray.Add(network)
}
return networkArray
}
// OnUpdatedHostDNS update the DNS servers addresses for root zones // OnUpdatedHostDNS update the DNS servers addresses for root zones
func (c *Client) OnUpdatedHostDNS(list *DNSList) error { func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
dnsServer, err := dns.GetServerDns() dnsServer, err := dns.GetServerDns()

View File

@@ -0,0 +1,27 @@
//go:build android
package android
type Network struct {
Name string
Network string
Peer string
Status string
}
type NetworkArray struct {
items []Network
}
func (array *NetworkArray) Add(s Network) *NetworkArray {
array.items = append(array.items, s)
return array
}
func (array *NetworkArray) Get(i int) *Network {
return &array.items[i]
}
func (array *NetworkArray) Size() int {
return len(array.items)
}

View File

@@ -7,30 +7,23 @@ type PeerInfo struct {
ConnStatus string // Todo replace to enum ConnStatus string // Todo replace to enum
} }
// PeerInfoCollection made for Java layer to get non default types as collection // PeerInfoArray is a wrapper of []PeerInfo
type PeerInfoCollection interface {
Add(s string) PeerInfoCollection
Get(i int) string
Size() int
}
// PeerInfoArray is the implementation of the PeerInfoCollection
type PeerInfoArray struct { type PeerInfoArray struct {
items []PeerInfo items []PeerInfo
} }
// Add new PeerInfo to the collection // Add new PeerInfo to the collection
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray { func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
array.items = append(array.items, s) array.items = append(array.items, s)
return array return array
} }
// Get return an element of the collection // Get return an element of the collection
func (array PeerInfoArray) Get(i int) *PeerInfo { func (array *PeerInfoArray) Get(i int) *PeerInfo {
return &array.items[i] return &array.items[i]
} }
// Size return with the size of the collection // Size return with the size of the collection
func (array PeerInfoArray) Size() int { func (array *PeerInfoArray) Size() int {
return len(array.items) return len(array.items)
} }

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences export a subset of the internal config for gomobile // Preferences exports a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences create new Preferences instance // NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL read url from config file // GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL store the given url and wait for commit // SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL read url from config file // GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL store the given url and wait for commit // SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey read preshared key from config file // GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@@ -66,12 +66,160 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey store the given key and wait for commit // SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// Commit write out the changes into config file // SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled
}
// GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassEnabled, err
}
// SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive
}
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.RosenpassPermissive, err
}
// GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip] return a.ipAnonymizer[ip]
} }
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs // isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -39,7 +39,6 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
enableLazyConnectionFlag = "enable-lazy-connection" enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle" uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url" uploadBundleURL = "upload-bundle-url"
@@ -78,7 +77,6 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool
debugUploadBundle bool debugUploadBundle bool
debugUploadBundleURL string debugUploadBundleURL string
lazyConnEnabled bool lazyConnEnabled bool

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
return &service.Config{ config := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.", Description: "Netbird mesh network client",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string),
} }
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
} }
func newSVC(prg *program, conf *service.Config) (service.Service, error) { func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
} }
if logFile != "console" { if logFile != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }

View File

@@ -69,7 +69,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) { status := resp.GetStatus()
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
status == string(internal.StatusSessionExpired) {
cmd.Printf("Daemon status: %s\n\n"+ cmd.Printf("Daemon status: %s\n\n"+
"Run UP command to log in with SSO (interactive login):\n\n"+ "Run UP command to log in with SSO (interactive login):\n\n"+
" netbird up \n\n"+ " netbird up \n\n"+

View File

@@ -6,6 +6,8 @@ const (
disableServerRoutesFlag = "disable-server-routes" disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns" disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
) )
var ( var (
@@ -13,6 +15,8 @@ var (
disableServerRoutes bool disableServerRoutes bool
disableDNS bool disableDNS bool
disableFirewall bool disableFirewall bool
blockLANAccess bool
blockInbound bool
) )
func init() { func init() {
@@ -28,4 +32,11 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false, upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.") "Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
} }

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
} }
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) cmd.Printf("Packet trace %s:%d %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
for _, stage := range resp.Stages { for _, stage := range resp.Stages {
if stage.ForwardingDetails != nil { if stage.ForwardingDetails != nil {

View File

@@ -55,12 +55,11 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil, upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+ `Sets DNS labels`+
@@ -119,83 +118,9 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
ic := internal.ConfigInput{ ic, err := setupConfig(customDNSAddressConverted, cmd)
ManagementURL: managementURL, if err != nil {
AdminURL: adminURL, return fmt.Errorf("setup config: %v", err)
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
} }
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
@@ -203,7 +128,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
config, err := internal.UpdateOrCreateConfig(ic) config, err := internal.UpdateOrCreateConfig(*ic)
if err != nil { if err != nil {
return fmt.Errorf("get config file: %v", err) return fmt.Errorf("get config file: %v", err)
} }
@@ -262,9 +187,141 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
providedSetupKey, err := getSetupKey() providedSetupKey, err := getSetupKey()
if err != nil { if err != nil {
return err return fmt.Errorf("get setup key: %v", err)
} }
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{
ManagementURL: managementURL,
AdminURL: adminURL,
ConfigPath: configPath,
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}
if cmd.Flag(enableRosenpassFlag).Changed {
ic.RosenpassEnabled = &rosenpassEnabled
}
if cmd.Flag(rosenpassPermissiveFlag).Changed {
ic.RosenpassPermissive = &rosenpassPermissive
}
if cmd.Flag(serverSSHAllowedFlag).Changed {
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
}
ic.InterfaceName = &interfaceName
}
if cmd.Flag(wireguardPortFlag).Changed {
p := int(wireguardPort)
ic.WireguardPort = &p
}
if cmd.Flag(networkMonitorFlag).Changed {
ic.NetworkMonitor = &networkMonitor
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
ic.PreSharedKey = &preSharedKey
}
if cmd.Flag(disableAutoConnectFlag).Changed {
ic.DisableAutoConnect = &autoConnectDisabled
if autoConnectDisabled {
cmd.Println("Autoconnect has been disabled. The client won't connect automatically when the service starts.")
}
if !autoConnectDisabled {
cmd.Println("Autoconnect has been enabled. The client will connect automatically when the service starts.")
}
}
if cmd.Flag(dnsRouteIntervalFlag).Changed {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
if cmd.Flag(blockInboundFlag).Changed {
ic.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
@@ -301,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return err return nil, err
} }
loginRequest.InterfaceName = &interfaceName loginRequest.InterfaceName = &interfaceName
} }
@@ -336,49 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
if cmd.Flag(blockInboundFlag).Changed {
loginRequest.BlockInbound = &blockInbound
}
if cmd.Flag(enableLazyConnectionFlag).Changed { if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled loginRequest.LazyConnectionEnabled = &lazyConnEnabled
} }
return &loginRequest, nil
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
} }
func validateNATExternalIPs(list []string) error { func validateNATExternalIPs(list []string) error {

View File

@@ -147,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -198,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
@@ -219,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -248,10 +248,6 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -278,10 +274,6 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if pair.Masquerade { if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)

View File

@@ -116,6 +116,8 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,

View File

@@ -170,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -324,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }

View File

@@ -3,7 +3,6 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1").Unmap()
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: add.AsSlice(), Data: ip.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -282,8 +273,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(

View File

@@ -573,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -1006,10 +1002,6 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }

View File

@@ -62,5 +62,5 @@ type ConnKey struct {
} }
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }

View File

@@ -3,6 +3,7 @@ package conntrack
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@@ -19,6 +20,10 @@ const (
DefaultICMPTimeout = 30 * time.Second DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections // ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second ICMPCleanupInterval = 15 * time.Second
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
// which includes the IP header (20 bytes) and transport header (8 bytes)
MaxICMPPayloadLength = 28
) )
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
@@ -29,7 +34,7 @@ type ICMPConnKey struct {
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID) return fmt.Sprintf("%s %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@@ -50,6 +55,72 @@ type ICMPTracker struct {
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
type ICMPInfo struct {
TypeCode layers.ICMPv4TypeCode
PayloadData [MaxICMPPayloadLength]byte
// actual length of valid data
PayloadLen int
}
// String implements fmt.Stringer for lazy evaluation in log messages
func (info ICMPInfo) String() string {
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
if origInfo := info.parseOriginalPacket(); origInfo != "" {
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
}
}
return info.TypeCode.String()
}
// isErrorMessage returns true if this ICMP type carries original packet info
func (info ICMPInfo) isErrorMessage() bool {
typ := info.TypeCode.Type()
return typ == 3 || // Destination Unreachable
typ == 5 || // Redirect
typ == 11 || // Time Exceeded
typ == 12 // Parameter Problem
}
// parseOriginalPacket extracts info about the original packet from ICMP payload
func (info ICMPInfo) parseOriginalPacket() string {
if info.PayloadLen < MaxICMPPayloadLength {
return ""
}
// TODO: handle IPv6
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
return ""
}
protocol := info.PayloadData[9]
srcIP := net.IP(info.PayloadData[12:16])
dstIP := net.IP(info.PayloadData[16:20])
transportData := info.PayloadData[20:]
switch nftypes.Protocol(protocol) {
case nftypes.TCP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.UDP:
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
case nftypes.ICMP:
icmpType := transportData[0]
icmpCode := transportData[1]
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
default:
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
}
}
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
@@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) { func (t *ICMPTracker) TrackOutbound(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
payload []byte,
size int,
) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) { func (t *ICMPTracker) TrackInbound(
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size) srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
ruleId []byte,
payload []byte,
size int,
) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) { func (t *ICMPTracker) track(
srcIP netip.Addr,
dstIP netip.Addr,
id uint16,
typecode layers.ICMPv4TypeCode,
direction nftypes.Direction,
ruleId []byte,
payload []byte,
size int,
) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists { if exists {
return return
} }
typ, code := typecode.Type(), typecode.Code() typ, code := typecode.Type(), typecode.Code()
icmpInfo := ICMPInfo{
TypeCode: typecode,
}
if len(payload) > 0 {
icmpInfo.PayloadLen = len(payload)
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
icmpInfo.PayloadLen = MaxICMPPayloadLength
}
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
}
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
t.sendEvent(nftypes.TypeStart, conn, ruleId) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }

View File

@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
} }
}) })
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
} }
b.ResetTimer() b.ResetTimer()

View File

@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
func (i epID) String() string { func (i epID) String() string {
// src and remote is swapped // src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort) return fmt.Sprintf("%s:%d %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
} }

View File

@@ -41,7 +41,7 @@ type Forwarder struct {
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip net.IP ip tcpip.Address
netstack bool netstack bool
} }
@@ -71,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: ones, PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@@ -116,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: iface.Address().IP, ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -167,7 +166,7 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
@@ -179,7 +178,6 @@ func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uin
} }
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) { func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok { if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok { } else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {

View File

@@ -111,12 +111,12 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
if errInToOut != nil { if errInToOut != nil {
if !isClosedError(errInToOut) { if !isClosedError(errInToOut) {
f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut) f.logger.Error("proxyTCP: copy error (in out) for %s: %v", epID(id), errInToOut)
} }
} }
if errOutToIn != nil { if errOutToIn != nil {
if !isClosedError(errOutToIn) { if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn) f.logger.Error("proxyTCP: copy error (out in) for %s: %v", epID(id), errOutToIn)
} }
} }

View File

@@ -250,10 +250,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
wg.Wait() wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) { if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr) f.logger.Error("proxyUDP: copy error (outboundinbound) for %s: %v", epID(id), outboundErr)
} }
if inboundErr != nil && !isClosedError(inboundErr) { if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr) f.logger.Error("proxyUDP: copy error (inboundoutbound) for %s: %v", epID(id), inboundErr)
} }
var rxPackets, txPackets uint64 var rxPackets, txPackets uint64

View File

@@ -45,8 +45,12 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
} }
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if ipv4 := ip.To4(); ipv4 != nil { if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0]) high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3]) low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
@@ -58,11 +62,9 @@ func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap,
bit := low % 32 bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit bitmap[high].bitmap[index] |= 1 << bit
ipStr := ipv4.String() if _, exists := ipv4Set[ip]; !exists {
if _, exists := ipv4Set[ipStr]; !exists { ipv4Set[ip] = struct{}{}
ipv4Set[ipStr] = struct{}{} *ipv4Addresses = append(*ipv4Addresses, ip)
*ipv4Addresses = append(*ipv4Addresses, ipStr)
}
} }
} }
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0 return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
} }
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses) m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
continue continue
} }
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil { addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
}() }()
var newIPv4Bitmap [256]*ipv4LowBitmap var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{}) ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []string var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // 127.0.0.0/8
newIPv4Bitmap[127] = &ipv4LowBitmap{} newIPv4Bitmap[127] = &ipv4LowBitmap{}

View File

@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match - addresses 32 apart", name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.33"), testIP: netip.MustParseAddr("192.168.1.33"),
expected: false, expected: false,
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: netip.MustParseAddr("fe80::1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,

View File

@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -39,8 +39,12 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack // EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible // Default off as it might be security risk because sockets listening on localhost only will become accessible.
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
@@ -71,7 +75,6 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -148,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
} }
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
} }
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
@@ -269,7 +277,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported(): case !m.netstack && m.nativeFirewall != nil:
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
@@ -326,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
@@ -606,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
if m.stateful { // for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
}
return false return false
} }
@@ -660,7 +671,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
} }
} }
@@ -673,7 +684,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
} }
} }
@@ -777,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// if running in netstack mode we need to pass this to the forwarder // If requested we pass local traffic to internal interfaces to the forwarder.
if m.netstack && m.localForwarding { // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
return m.handleNetstackLocalTraffic(packetData) if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
@@ -789,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false return false
} }
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load() fwd := m.forwarder.Load()
if fwd == nil { if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -1088,11 +1099,6 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not

View File

@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept) dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -19,12 +19,8 @@ import (
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := net.ParseIP("100.10.0.100") localIP := netip.MustParseAddr("100.10.0.100")
wgNet := &net.IPNet{ wgNet := netip.MustParsePrefix("100.10.0.0/16")
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@@ -43,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -581,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network) wgNet := netip.MustParsePrefix(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: localIP, IP: wgNet.Addr(),
Network: wgNet, Network: wgNet,
} }
}, },
@@ -1440,11 +1433,8 @@ func TestRouteACLSet(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }

View File

@@ -271,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -285,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@@ -396,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@@ -509,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{

View File

@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@@ -0,0 +1,17 @@
package configurer
import (
"net"
"net/netip"
)
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -5,6 +5,7 @@ package configurer
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -12,6 +13,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -43,7 +46,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
return nil return nil
} }
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -52,7 +55,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps, AllowedIPs: prefixesToIPNets(allowedIps),
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
Endpoint: endpoint, Endpoint: endpoint,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
@@ -89,10 +92,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
return nil return nil
} }
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(allowedIP) ipNet := net.IPNet{
if err != nil { IP: allowedIP.Addr().AsSlice(),
return err Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -103,7 +106,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@@ -116,10 +119,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
return nil return nil
} }
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error { func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(allowedIP) ipNet := net.IPNet{
if err != nil { IP: allowedIP.Addr().AsSlice(),
return fmt.Errorf("parse allowed IP: %w", err) Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -187,7 +190,11 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
if err != nil { if err != nil {
return err return err
} }
defer wg.Close() defer func() {
if err := wg.Close(); err != nil {
log.Errorf("Failed to close wgctrl client: %v", err)
}
}()
// validate if device with name exists // validate if device with name exists
_, err = wg.Device(c.deviceName) _, err = wg.Device(c.deviceName)
@@ -201,6 +208,47 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) FullStats() (*Stats, error) {
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) { func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats) stats := make(map[string]WGStats)
wg, err := wgctrl.New() wg, err := wgctrl.New()

View File

@@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"os" "os"
"runtime" "runtime"
"strconv" "strconv"
@@ -19,10 +20,17 @@ import (
) )
const ( const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec" ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec" ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes" ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes" ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
) )
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
@@ -60,7 +68,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -69,7 +77,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
// don't replace allowed ips, wg will handle duplicated peer IP // don't replace allowed ips, wg will handle duplicated peer IP
AllowedIPs: allowedIps, AllowedIPs: prefixesToIPNets(allowedIps),
PersistentKeepaliveInterval: &keepAlive, PersistentKeepaliveInterval: &keepAlive,
PresharedKey: preSharedKey, PresharedKey: preSharedKey,
Endpoint: endpoint, Endpoint: endpoint,
@@ -99,10 +107,10 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error { func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
_, ipNet, err := net.ParseCIDR(allowedIP) ipNet := net.IPNet{
if err != nil { IP: allowedIP.Addr().AsSlice(),
return err Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
} }
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
@@ -113,7 +121,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
PublicKey: peerKeyParsed, PublicKey: peerKeyParsed,
UpdateOnly: true, UpdateOnly: true,
ReplaceAllowedIPs: false, ReplaceAllowedIPs: false,
AllowedIPs: []net.IPNet{*ipNet}, AllowedIPs: []net.IPNet{ipNet},
} }
config := wgtypes.Config{ config := wgtypes.Config{
@@ -123,7 +131,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error { func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
ipc, err := c.device.IpcGet() ipc, err := c.device.IpcGet()
if err != nil { if err != nil {
return err return err
@@ -146,6 +154,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
foundPeer := false foundPeer := false
removedAllowedIP := false removedAllowedIP := false
ip := allowedIP.String()
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
@@ -168,8 +178,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
// Append the line to the output string // Append the line to the output string
if foundPeer && strings.HasPrefix(line, "allowed_ip=") { if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
allowedIP := strings.TrimPrefix(line, "allowed_ip=") allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
_, ipNet, err := net.ParseCIDR(allowedIP) _, ipNet, err := net.ParseCIDR(allowedIPStr)
if err != nil { if err != nil {
return err return err
} }
@@ -186,6 +196,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
@@ -365,3 +384,136 @@ func getFwmark() int {
} }
return 0 return 0
} }
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -0,0 +1,24 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@@ -1,7 +1,6 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -24,9 +23,6 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"time" "time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -11,10 +12,11 @@ import (
type WGConfigurer interface { type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
Close() Close()
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
} }

View File

@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance
@@ -111,14 +112,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
} }
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist // UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
// Endpoint is optional // Endpoint is optional.
// If allowedIps is given it will be added to the existing ones.
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
netIPNets := prefixesToIPNets(allowedIps) log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@@ -131,7 +132,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
} }
// AddAllowedIP adds a prefix to the allowed IPs list of peer // AddAllowedIP adds a prefix to the allowed IPs list of peer
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error { func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -140,7 +141,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
} }
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer // RemoveAllowedIP removes a prefix from the allowed IPs list of peer
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -185,7 +186,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
@@ -217,6 +217,10 @@ func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats() return w.configurer.GetStats()
} }
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
}
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {
maxWaitTime := 5 * time.Second maxWaitTime := 5 * time.Second
timeout := time.NewTimer(maxWaitTime) timeout := time.NewTimer(maxWaitTime)
@@ -251,14 +255,3 @@ func (w *WGIface) GetNet() *netstack.Net {
return w.tun.GetNet() return w.tun.GetNet()
} }
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
ipNets := make([]net.IPNet, len(prefixes))
for i, prefix := range prefixes {
ipNets[i] = net.IPNet{
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
}
}
return ipNets
}

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -5,7 +5,6 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,8 +1,6 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -15,8 +13,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address net.IP address netip.Addr
dnsAddress net.IP dnsAddress netip.Addr
mtu int mtu int
listenAddress string listenAddress string
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{addr.Unmap()}, []netip.Addr{t.address},
[]netip.Addr{dnsAddr.Unmap()}, []netip.Addr{t.dnsAddress},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -2,28 +2,27 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net" "net/netip"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP net.IP IP netip.Addr
Network *net.IPNet Network netip.Prefix
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) prefix, err := netip.ParsePrefix(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: ip, IP: prefix.Addr().Unmap(),
Network: network, Network: prefix.Masked(),
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -58,6 +58,11 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -69,14 +74,8 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
time.Since(start), total) time.Since(start), total)
}() }()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -285,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. if d.firewall.IsStateful() {
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -397,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
if drop { r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{ protocols[r.Protocol] = &protoMatch{
ips: map[string]int{}, ips: map[string]int{},

View File

@@ -1,13 +1,14 @@
package acl package acl
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) assert.Equal(t, 0, len(acl.peerRulesPairs))
return } else {
assert.Equal(t, 2, len(acl.peerRulesPairs))
} }
}) })
@@ -94,12 +91,13 @@ func TestDefaultManager(t *testing.T) {
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules expectedRules := 2
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied") expectedRules = 1 // only the inbound rule
return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
previousCount++ previousCount++
} }
} }
if previousCount != 1 {
t.Errorf("old rule was not removed") expectedPreviousCount := 0
if !fw.IsStateful() {
expectedPreviousCount = 1
} }
assert.Equal(t, expectedPreviousCount, previousCount)
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 { acl.ApplyFiltering(networkMap, false)
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) assert.Equal(t, 0, len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) expectedRules := 1
return if fw.IsStateful() {
expectedRules = 1 // only inbound allow-all rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
}) })
} }
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) rules, _ := manager.squashAcceptRules(networkMap)
if len(rules) != 2 { assert.Equal(t, 2, len(rules))
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0] r := rules[0]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1] r = rules[1]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
} }
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -291,8 +326,435 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
} }
manager := &DefaultManager{} manager := &DefaultManager{}
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) { rules, _ := manager.squashAcceptRules(networkMap)
t.Errorf("we should get the same amount of rules as output, got %v", len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
} }
} }
@@ -336,33 +798,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap, false) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 { expectedRules := 3
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) if fw.IsStateful() {
return expectedRules = 3 // 2 inbound rules + SSH rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
} }

View File

@@ -68,8 +68,8 @@ type ConfigInput struct {
DisableServerRoutes *bool DisableServerRoutes *bool
DisableDNS *bool DisableDNS *bool
DisableFirewall *bool DisableFirewall *bool
BlockLANAccess *bool BlockLANAccess *bool
BlockInbound *bool
DisableNotifications *bool DisableNotifications *bool
@@ -98,8 +98,8 @@ type Config struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool
DisableNotifications *bool DisableNotifications *bool
@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {
@@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true updated = true
} else if config.ServerSSHAllowed == nil { } else if config.ServerSSHAllowed == nil {
if runtime.GOOS == "android" {
// default to disabled SSH on Android for security
log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility // enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration") log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True() config.ServerSSHAllowed = util.True()
}
updated = true updated = true
} }
@@ -483,6 +491,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
if *input.BlockInbound {
log.Infof("blocking inbound connections")
} else {
log.Infof("allowing inbound connections")
}
config.BlockInbound = *input.BlockInbound
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications { if *input.DisableNotifications {
log.Infof("disabling notifications") log.Infof("disabling notifications")

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
) )
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections. // ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
@@ -34,8 +35,8 @@ type ConnMgr struct {
lazyConnMgr *manager.Manager lazyConnMgr *manager.Manager
wg sync.WaitGroup wg sync.WaitGroup
ctx context.Context lazyCtx context.Context
ctxCancel context.CancelFunc lazyCtxCancel context.CancelFunc
} }
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr { func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
@@ -85,7 +86,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
log.Infof("lazy connection manager is enabled by management feature flag") log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx) e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true) e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager(ctx) return e.addPeersToLazyConnManager()
} else { } else {
if e.lazyConnMgr == nil { if e.lazyConnMgr == nil {
return nil return nil
@@ -97,15 +98,25 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
} }
} }
// UpdateRouteHAMap updates the route HA mappings in the lazy connection manager
func (e *ConnMgr) UpdateRouteHAMap(haMap route.HAMap) {
if !e.isStartedWithLazyMgr() {
log.Debugf("lazy connection manager is not started, skipping UpdateRouteHAMap")
return
}
e.lazyConnMgr.UpdateRouteHAMap(haMap)
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections. // SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs []string) { func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
if e.lazyConnMgr == nil { if e.lazyConnMgr == nil {
return return
} }
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs)) excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for _, peerID := range peerIDs { for peerID := range peerIDs {
var peerConn *peer.Conn var peerConn *peer.Conn
var exists bool var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists { if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
@@ -122,7 +133,7 @@ func (e *ConnMgr) SetExcludeList(peerIDs []string) {
excludedPeers = append(excludedPeers, lazyPeerCfg) excludedPeers = append(excludedPeers, lazyPeerCfg)
} }
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers) added := e.lazyConnMgr.ExcludePeer(e.lazyCtx, excludedPeers)
for _, peerID := range added { for _, peerID := range added {
var peerConn *peer.Conn var peerConn *peer.Conn
var exists bool var exists bool
@@ -132,7 +143,7 @@ func (e *ConnMgr) SetExcludeList(peerIDs []string) {
} }
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection") peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(e.ctx); err != nil { if err := peerConn.Open(ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err) peerConn.Log.Errorf("failed to open connection: %v", err)
} }
} }
@@ -164,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(), PeerConnID: conn.ConnID(),
Log: conn.Log, Log: conn.Log,
} }
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg) excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil { if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {
@@ -210,9 +221,9 @@ func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn,
return conn, true return conn, true
} }
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found { if found := e.lazyConnMgr.ActivatePeer(e.lazyCtx, peerKey); found {
conn.Log.Infof("activated peer from inactive state") conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(e.ctx); err != nil { if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err) conn.Log.Errorf("failed to open connection: %v", err)
} }
} }
@@ -224,29 +235,27 @@ func (e *ConnMgr) Close() {
return return
} }
e.ctxCancel() e.lazyCtxCancel()
e.wg.Wait() e.wg.Wait()
e.lazyConnMgr = nil e.lazyConnMgr = nil
} }
func (e *ConnMgr) initLazyManager(parentCtx context.Context) { func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
cfg := manager.Config{ cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(), InactivityThreshold: inactivityThresholdEnv(),
} }
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher) e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface, e.dispatcher)
ctx, cancel := context.WithCancel(parentCtx) e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1) e.wg.Add(1)
go func() { go func() {
defer e.wg.Done() defer e.wg.Done()
e.lazyConnMgr.Start(ctx) e.lazyConnMgr.Start(e.lazyCtx)
}() }()
} }
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error { func (e *ConnMgr) addPeersToLazyConnManager() error {
peers := e.peerStore.PeersPubKey() peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers)) lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers { for _, peerID := range peers {
@@ -266,7 +275,7 @@ func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg) lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
} }
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs) return e.lazyConnMgr.AddActivePeers(e.lazyCtx, lazyPeerCfgs)
} }
func (e *ConnMgr) closeManager(ctx context.Context) { func (e *ConnMgr) closeManager(ctx context.Context) {
@@ -274,7 +283,7 @@ func (e *ConnMgr) closeManager(ctx context.Context) {
return return
} }
e.ctxCancel() e.lazyCtxCancel()
e.wg.Wait() e.wg.Wait()
e.lazyConnMgr = nil e.lazyConnMgr = nil
@@ -284,7 +293,7 @@ func (e *ConnMgr) closeManager(ctx context.Context) {
} }
func (e *ConnMgr) isStartedWithLazyMgr() bool { func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.ctxCancel != nil return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
} }
func inactivityThresholdEnv() *time.Duration { func inactivityThresholdEnv() *time.Duration {

View File

@@ -436,11 +436,12 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DNSRouteInterval: config.DNSRouteInterval, DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes, DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess, BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled, LazyConnectionEnabled: config.LazyConnectionEnabled,
} }
@@ -499,6 +500,9 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.DisableServerRoutes, config.DisableServerRoutes,
config.DisableDNS, config.DisableDNS,
config.DisableFirewall, config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
) )
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {

View File

@@ -270,11 +270,21 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("Failed to add corrupted state files to debug bundle: %v", err) log.Errorf("Failed to add corrupted state files to debug bundle: %v", err)
} }
if g.logFile != "console" { if err := g.addWgShow(); err != nil {
log.Errorf("Failed to add wg show output: %v", err)
}
if g.logFile != "console" && g.logFile != "" {
if err := g.addLogfile(); err != nil { if err := g.addLogfile(); err != nil {
return fmt.Errorf("add log file: %w", err) log.Errorf("Failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs as fallback: %v", err)
} }
} }
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("Failed to add systemd logs: %v", err)
}
return nil return nil
} }
@@ -366,17 +376,33 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled)) configContent.WriteString(fmt.Sprintf("RosenpassEnabled: %v\n", g.internalConfig.RosenpassEnabled))
configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive)) configContent.WriteString(fmt.Sprintf("RosenpassPermissive: %v\n", g.internalConfig.RosenpassPermissive))
if g.internalConfig.ServerSSHAllowed != nil { if g.internalConfig.ServerSSHAllowed != nil {
configContent.WriteString(fmt.Sprintf("BundleGeneratorSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed)) configContent.WriteString(fmt.Sprintf("ServerSSHAllowed: %v\n", *g.internalConfig.ServerSSHAllowed))
} }
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableBundleGeneratorRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS)) configContent.WriteString(fmt.Sprintf("DisableDNS: %v\n", g.internalConfig.DisableDNS))
configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall)) configContent.WriteString(fmt.Sprintf("DisableFirewall: %v\n", g.internalConfig.DisableFirewall))
configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess)) configContent.WriteString(fmt.Sprintf("BlockLANAccess: %v\n", g.internalConfig.BlockLANAccess))
configContent.WriteString(fmt.Sprintf("BlockInbound: %v\n", g.internalConfig.BlockInbound))
if g.internalConfig.DisableNotifications != nil {
configContent.WriteString(fmt.Sprintf("DisableNotifications: %v\n", *g.internalConfig.DisableNotifications))
}
configContent.WriteString(fmt.Sprintf("DNSLabels: %v\n", g.internalConfig.DNSLabels))
configContent.WriteString(fmt.Sprintf("DisableAutoConnect: %v\n", g.internalConfig.DisableAutoConnect))
configContent.WriteString(fmt.Sprintf("DNSRouteInterval: %s\n", g.internalConfig.DNSRouteInterval))
if g.internalConfig.ClientCertPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertPath: %s\n", g.internalConfig.ClientCertPath))
}
if g.internalConfig.ClientCertKeyPath != "" {
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
} }

View File

@@ -4,17 +4,104 @@ package debug
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sort" "sort"
"strings" "strings"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
)
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
func (g *BundleGenerator) trySystemdLogFallback() error {
log.Debug("Attempting to collect systemd journal logs")
serviceName := getServiceName()
journalLogs, err := getSystemdLogs(serviceName)
if err != nil {
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
}
if strings.Contains(journalLogs, "No recent log entries found") {
log.Debug("No recent log entries found in systemd journal")
return nil
}
if g.anonymize {
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
}
logReader := strings.NewReader(journalLogs)
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
if err := g.addFileToZip(logReader, fileName); err != nil {
return fmt.Errorf("add systemd logs to bundle: %w", err)
}
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
return nil
}
// getServiceName gets the service name from environment or defaults to netbird
func getServiceName() string {
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
return unitName
}
return "netbird"
}
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
func getSystemdLogs(serviceName string) (string, error) {
args := []string{
"-u", fmt.Sprintf("%s.service", serviceName),
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
"--lines", fmt.Sprintf("%d", maxLogEntries),
"--no-pager",
"--output", "short-iso",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "journalctl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
}
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("journalctl command not found: %w", err)
}
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
}
logs := stdout.String()
if strings.TrimSpace(logs) == "" {
return "No recent log entries found in systemd journal", nil
}
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
serviceName, maxLogEntries, maxLogAge.String())
return header + logs, nil
}
// addFirewallRules collects and adds firewall rules to the archive // addFirewallRules collects and adds firewall rules to the archive
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules") log.Info("Collecting firewall rules")
@@ -481,7 +568,7 @@ func formatExpr(exp expr.Any) string {
case *expr.Fib: case *expr.Fib:
return formatFib(e) return formatFib(e)
case *expr.Target: case *expr.Target:
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets return fmt.Sprintf("jump %s", e.Name)
case *expr.Immediate: case *expr.Immediate:
if e.Register == 1 { if e.Register == 1 {
return formatImmediateData(e.Data) return formatImmediateData(e.Data)

View File

@@ -6,3 +6,9 @@ package debug
func (g *BundleGenerator) addFirewallRules() error { func (g *BundleGenerator) addFirewallRules() error {
return nil return nil
} }
func (g *BundleGenerator) trySystemdLogFallback() error {
// Systemd is only available on Linux
// TODO: Add BSD support
return nil
}

View File

@@ -0,0 +1,66 @@
package debug
import (
"bytes"
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGIface interface {
FullStats() (*configurer.Stats, error)
}
func (g *BundleGenerator) addWgShow() error {
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err
}
output := g.toWGShowFormat(result)
reader := bytes.NewReader([]byte(output))
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
return fmt.Errorf("add wg show to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
if s.FWMark != 0 {
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
}
for _, peer := range s.Peers {
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
if peer.Endpoint.IP != nil {
if g.anonymize {
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
} else {
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
}
}
if len(peer.AllowedIPs) > 0 {
var ipStrings []string
for _, ipnet := range peer.AllowedIPs {
ipStrings = append(ipStrings, ipnet.String())
}
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}
return sb.String()
}

View File

@@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"strings" "strings"
@@ -12,13 +12,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData) ip, err := netip.ParseAddr(aRecord.RData)
if ip == nil || ip.To4() == nil { if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
if !ipNet.Contains(ip) { if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
} }
// generateReverseZoneName creates the reverse DNS zone name for a given network // generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) { func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask) networkIP := network.Masked().Addr()
maskOnes, _ := ipNet.Mask.Size()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte // round up to nearest byte
octetsToUse := (maskOnes + 7) / 8 octetsToUse := (network.Bits() + 7) / 8
octets := strings.Split(networkIP.String(), ".") octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) { if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
} }
reverseOctets := make([]string, octetsToUse) reverseOctets := make([]string, octetsToUse)
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
} }
// collectPTRRecords gathers all PTR records for the given network from A records // collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones { for _, zone := range config.CustomZones {
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
continue continue
} }
if ptrRecord, ok := createPTRRecord(record, ipNet); ok { if ptrRecord, ok := createPTRRecord(record, prefix); ok {
records = append(records, ptrRecord) records = append(records, ptrRecord)
} }
} }
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
} }
// addReverseZone adds a reverse DNS zone to the configuration for the given network // addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(ipNet) zoneName, err := generateReverseZoneName(network)
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
return return
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
return return
} }
records := collectPTRRecords(config, ipNet) records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{ reverseZone := nbdns.CustomZone{
Domain: zoneName, Domain: zoneName,

View File

@@ -1,6 +1,7 @@
package dns package dns
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
"sync" "sync"
@@ -10,8 +11,9 @@ import (
) )
const ( const (
PriorityDNSRoute = 100 PriorityLocal = 100
PriorityMatchDomain = 50 PriorityDNSRoute = 75
PriorityUpstream = 50
PriorityDefault = 1 PriorityDefault = 1
) )
@@ -148,47 +150,27 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
qname := strings.ToLower(r.Question[0].Name) qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
handlers := slices.Clone(c.handlers) handlers := slices.Clone(c.handlers)
c.mu.RUnlock() c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
} }
log.Trace(strings.TrimSuffix(b.String(), "\n"))
} }
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
var matched bool matched := c.isHandlerMatch(qname, entry)
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
}
}
if !matched { if matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false", log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue
}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{ chainWriter := &ResponseWriterChain{
@@ -199,11 +181,12 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler // If handler wants to continue, try next handler
if chainWriter.shouldContinue { if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler") log.Tracef("handler requested continue to next handler for domain=%s", qname)
continue continue
} }
return return
} }
}
// No handler matched or all handlers passed // No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname) log.Tracef("no handler found for domain=%s", qname)
@@ -213,3 +196,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
} }
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
return strings.EqualFold(qname, entry.Pattern)
}
}
}

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "test.example.com.", queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "sub.test.example.com.", queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true, nbdns.PriorityUpstream: true,
}, },
}, },
{ {
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true, nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
}, },
}, },
{ {
@@ -378,15 +378,15 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityDefault}, {"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true, nbdns.PriorityDefault: true,
}, },
}, },
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
// Test 1: Initial state // Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool shouldMatch bool
}{ }{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false}, {"example.com.", nbdns.PriorityUpstream, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true},
}, },
query: "example.com.", query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "test.sub.example.com.", query: "test.sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true}, {"add", "example.com.", nbdns.PriorityDNSRoute, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",

View File

@@ -1,11 +1,14 @@
package dns package dns
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"os/exec"
"strings" "strings"
"syscall" "syscall"
"time"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -41,6 +44,20 @@ const (
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
// Network interface DNS registration settings
disableDynamicUpdateKey = "DisableDynamicUpdate"
registrationEnabledKey = "RegistrationEnabled"
maxNumberOfAddressesToRegisterKey = "MaxNumberOfAddressesToRegister"
// NetBIOS/WINS settings
netbtInterfacePath = `SYSTEM\CurrentControlSet\Services\NetBT\Parameters\Interfaces`
netbiosOptionsKey = "NetbiosOptions"
// NetBIOS option values: 0 = from DHCP, 1 = enabled, 2 = disabled
netbiosFromDHCP = 0
netbiosEnabled = 1
netbiosDisabled = 2
// RP_FORCE: Reapply all policies even if no policy change was detected // RP_FORCE: Reapply all policies even if no policy change was detected
rpForce = 0x1 rpForce = 0x1
) )
@@ -67,16 +84,85 @@ func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
log.Infof("detected GPO DNS policy configuration, using policy store") log.Infof("detected GPO DNS policy configuration, using policy store")
} }
return &registryConfigurator{ configurator := &registryConfigurator{
guid: guid, guid: guid,
gpo: useGPO, gpo: useGPO,
}, nil }
if err := configurator.configureInterface(); err != nil {
log.Errorf("failed to configure interface settings: %v", err)
}
return configurator, nil
} }
func (r *registryConfigurator) supportCustomPort() bool { func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) configureInterface() error {
var merr *multierror.Error
if err := r.disableDNSRegistrationForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable DNS registration: %w", err))
}
if err := r.disableWINSForInterface(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("disable WINS: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableDNSRegistrationForInterface() error {
regKey, err := r.getInterfaceRegistryKey()
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closer(regKey)
var merr *multierror.Error
if err := regKey.SetDWordValue(disableDynamicUpdateKey, 1); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", disableDynamicUpdateKey, err))
}
if err := regKey.SetDWordValue(registrationEnabledKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", registrationEnabledKey, err))
}
if err := regKey.SetDWordValue(maxNumberOfAddressesToRegisterKey, 0); err != nil {
merr = multierror.Append(merr, fmt.Errorf("set %s: %w", maxNumberOfAddressesToRegisterKey, err))
}
if merr == nil || len(merr.Errors) == 0 {
log.Infof("disabled DNS registration for interface %s", r.guid)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *registryConfigurator) disableWINSForInterface() error {
netbtKeyPath := fmt.Sprintf(`%s\Tcpip_%s`, netbtInterfacePath, r.guid)
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
regKey, _, err = registry.CreateKey(registry.LOCAL_MACHINE, netbtKeyPath, registry.SET_VALUE)
if err != nil {
return fmt.Errorf("create NetBT interface key %s: %w", netbtKeyPath, err)
}
}
defer closer(regKey)
// NetbiosOptions: 2 = disabled
if err := regKey.SetDWordValue(netbiosOptionsKey, netbiosDisabled); err != nil {
return fmt.Errorf("set %s: %w", netbiosOptionsKey, err)
}
log.Infof("disabled WINS/NetBIOS for interface %s", r.guid)
return nil
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll { if config.RouteAll {
if err := r.addDNSSetupForAll(config.ServerIP); err != nil { if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
@@ -119,9 +205,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
if err := r.flushDNSCache(); err != nil { go r.flushDNSCache()
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@@ -191,7 +275,25 @@ func (r *registryConfigurator) string() string {
return "registry" return "registry"
} }
func (r *registryConfigurator) flushDNSCache() error { func (r *registryConfigurator) registerDNS() {
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
// nolint:misspell
cmd := exec.CommandContext(ctx, "ipconfig", "/registerdns")
out, err := cmd.CombinedOutput()
if err != nil {
log.Errorf("failed to register DNS: %v, output: %s", err, out)
return
}
log.Info("registered DNS names")
}
func (r *registryConfigurator) flushDNSCache() {
r.registerDNS()
// dnsFlushResolverCacheFn.Call() may panic if the func is not found // dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() { defer func() {
if rec := recover(); rec != nil { if rec := recover(); rec != nil {
@@ -202,13 +304,14 @@ func (r *registryConfigurator) flushDNSCache() error {
ret, _, err := dnsFlushResolverCacheFn.Call() ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 { if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) { if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err) log.Errorf("DnsFlushResolverCache failed: %v", err)
return
} }
return fmt.Errorf("DnsFlushResolverCache failed") log.Errorf("DnsFlushResolverCache failed")
return
} }
log.Info("flushed DNS cache") log.Info("flushed DNS cache")
return nil
} }
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
@@ -263,9 +366,7 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := r.flushDNSCache(); err != nil { go r.flushDNSCache()
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }

View File

@@ -12,16 +12,19 @@ import (
"github.com/netbirdio/netbird/client/internal/dns/types" "github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
type Resolver struct { type Resolver struct {
mu sync.RWMutex mu sync.RWMutex
records map[dns.Question][]dns.RR records map[dns.Question][]dns.RR
domains map[domain.Domain]struct{}
} }
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question][]dns.RR), records: make(map[dns.Question][]dns.RR),
domains: make(map[domain.Domain]struct{}),
} }
} }
@@ -64,8 +67,12 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage.Rcode = dns.RcodeSuccess replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...) replyMessage.Answer = append(replyMessage.Answer, records...)
} else { } else {
// TODO: return success if we have a different record type for the same name, relevant for search domains // Check if we have any records for this domain name with different types
replyMessage.Rcode = dns.RcodeNameError if d.hasRecordsForDomain(domain.Domain(question.Name)) {
replyMessage.Rcode = dns.RcodeSuccess // NOERROR with 0 records
} else {
replyMessage.Rcode = dns.RcodeNameError // NXDOMAIN
}
} }
if err := w.WriteMsg(replyMessage); err != nil { if err := w.WriteMsg(replyMessage); err != nil {
@@ -73,6 +80,15 @@ func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// hasRecordsForDomain checks if any records exist for the given domain name regardless of type
func (d *Resolver) hasRecordsForDomain(domainName domain.Domain) bool {
d.mu.RLock()
defer d.mu.RUnlock()
_, exists := d.domains[domainName]
return exists
}
// lookupRecords fetches *all* DNS records matching the first question in r. // lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR { func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock() d.mu.RLock()
@@ -111,6 +127,7 @@ func (d *Resolver) Update(update []nbdns.SimpleRecord) {
defer d.mu.Unlock() defer d.mu.Unlock()
maps.Clear(d.records) maps.Clear(d.records)
maps.Clear(d.domains)
for _, rec := range update { for _, rec := range update {
if err := d.registerRecord(rec); err != nil { if err := d.registerRecord(rec); err != nil {
@@ -144,6 +161,7 @@ func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
} }
d.records[q] = append(d.records[q], rr) d.records[q] = append(d.records[q], rr)
d.domains[domain.Domain(q.Name)] = struct{}{}
return nil return nil
} }

View File

@@ -470,3 +470,115 @@ func TestLocalResolver_CNAMEFallback(t *testing.T) {
}) })
} }
} }
// TestLocalResolver_NoErrorWithDifferentRecordType verifies that querying for a record type
// that doesn't exist but where other record types exist for the same domain returns NOERROR
// with 0 records instead of NXDOMAIN
func TestLocalResolver_NoErrorWithDifferentRecordType(t *testing.T) {
resolver := NewResolver()
recordA := nbdns.SimpleRecord{
Name: "example.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
}
recordCNAME := nbdns.SimpleRecord{
Name: "alias.netbird.cloud.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
resolver.Update([]nbdns.SimpleRecord{recordA, recordCNAME})
testCases := []struct {
name string
queryName string
queryType uint16
expectedRcode int
shouldHaveData bool
}{
{
name: "Query A record that exists",
queryName: "example.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query AAAA for domain with only A record",
queryName: "example.netbird.cloud.",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query other record with different case and non-fqdn",
queryName: "EXAMPLE.netbird.cloud",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query TXT for domain with only A record",
queryName: "example.netbird.cloud.",
queryType: dns.TypeTXT,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: false,
},
{
name: "Query A for domain with only CNAME record",
queryName: "alias.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query AAAA for domain with only CNAME record",
queryName: "alias.netbird.cloud.",
queryType: dns.TypeAAAA,
expectedRcode: dns.RcodeSuccess,
shouldHaveData: true,
},
{
name: "Query for completely non-existent domain",
queryName: "nonexistent.netbird.cloud.",
queryType: dns.TypeA,
expectedRcode: dns.RcodeNameError,
shouldHaveData: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
require.NotNil(t, responseMSG, "Should have received a response message")
assert.Equal(t, tc.expectedRcode, responseMSG.Rcode,
"Response code should be %d (%s)",
tc.expectedRcode, dns.RcodeToString[tc.expectedRcode])
if tc.shouldHaveData {
assert.Greater(t, len(responseMSG.Answer), 0, "Response should contain answers")
} else {
assert.Equal(t, 0, len(responseMSG.Answer), "Response should contain no answers")
}
})
}
}

View File

@@ -489,7 +489,7 @@ func (s *DefaultServer) applyHostConfig() {
} }
} }
log.Debugf("extra match domains: %v", s.extraDomains) log.Debugf("extra match domains: %v", maps.Keys(s.extraDomains))
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil { if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
log.Errorf("failed to apply DNS host manager update: %v", err) log.Errorf("failed to apply DNS host manager update: %v", err)
@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityLocal,
}) })
for _, record := range customZone.Records { for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups) groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS { for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain basePriority := PriorityUpstream
if domainGroup.domain == nbdns.RootZone { if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault basePriority = PriorityDefault
} }
@@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i priority := basePriority - i
// Check if we're about to overlap with the next priority tier // Check if we're about to overlap with the next priority tier.
if basePriority == PriorityMatchDomain && priority <= PriorityDefault { // This boundary check ensures that the priority of upstream handlers does not conflict
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault) domainGroup.domain, PriorityUpstream-PriorityDefault)
break break
} }

View File

@@ -46,10 +46,9 @@ func (w *mocWGIface) Name() string {
} }
func (w *mocWGIface) Address() wgaddr.Address { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return wgaddr.Address{ return wgaddr.Address{
IP: ip, IP: netip.MustParseAddr("100.66.100.1"),
Network: network, Network: netip.MustParsePrefix("100.66.100.0/24"),
} }
} }
@@ -165,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
dummyHandler.ID(): handlerWrapper{ dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
generateDummyHandler(".", nameServers).ID(): handlerWrapper{ generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone, domain: nbdns.RootZone,
@@ -187,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -211,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"local-resolver": handlerWrapper{ "local-resolver": handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
}, },
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -306,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -322,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -464,17 +463,10 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
_, ipNet, err := net.ParseCIDR("100.66.100.1/32")
if err != nil {
t.Errorf("parse CIDR: %v", err)
return
}
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet)
if err := wgIface.SetFilter(packetfilter); err != nil { if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err) t.Errorf("set packet filter: %v", err)
@@ -503,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -986,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
testCases := []struct { testCases := []struct {
name string name string
@@ -1067,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
} }
@@ -1101,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
"upstream-other": { "upstream-other": {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
@@ -1136,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1154,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1172,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain + 1, priority: PriorityUpstream + 1,
}, },
// Keep existing groups with their original priorities // Keep existing groups with their original priorities
{ {
@@ -1180,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1207,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
// Add group3 with lowest priority // Add group3 with lowest priority
{ {
@@ -1222,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain - 2, priority: PriorityUpstream - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1343,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1368,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "new.com", domain: "new.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-new", Id: "upstream-new",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1799,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain // Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
// Verify refcount is 2 // Verify refcount is 2
zoneKey := toZone("shared.example.com") zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler // Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
// Verify refcount is 1 // Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1933,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
} }
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1953,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
} }
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -24,11 +24,15 @@ type ServiceViaMemory struct {
} }
func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory { func NewServiceViaMemory(wgIface WGIface) *ServiceViaMemory {
lastIP, err := nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1)
if err != nil {
log.Errorf("get last ip from network: %v", err)
}
s := &ServiceViaMemory{ s := &ServiceViaMemory{
wgInterface: wgIface, wgInterface: wgIface,
dnsMux: dns.NewServeMux(), dnsMux: dns.NewServeMux(),
runtimeIP: nbnet.GetLastIPFromNetwork(wgIface.Address().Network, 1).String(), runtimeIP: lastIP.String(),
runtimePort: defaultPort, runtimePort: defaultPort,
} }
return s return s
@@ -91,7 +95,7 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
} }
firstLayerDecoder := layers.LayerTypeIPv4 firstLayerDecoder := layers.LayerTypeIPv4
if s.wgInterface.Address().Network.IP.To4() == nil { if s.wgInterface.Address().IP.Is6() {
firstLayerDecoder = layers.LayerTypeIPv6 firstLayerDecoder = layers.LayerTypeIPv6
} }

View File

@@ -1,33 +0,0 @@
package dns
import (
"net"
"testing"
nbnet "github.com/netbirdio/netbird/util/net"
)
func TestGetLastIPFromNetwork(t *testing.T) {
tests := []struct {
addr string
ip string
}{
{"2001:db8::/32", "2001:db8:ffff:ffff:ffff:ffff:ffff:fffe"},
{"192.168.0.0/30", "192.168.0.2"},
{"192.168.0.0/16", "192.168.255.254"},
{"192.168.0.0/24", "192.168.0.254"},
}
for _, tt := range tests {
_, ipnet, err := net.ParseCIDR(tt.addr)
if err != nil {
t.Errorf("Error parsing CIDR: %v", err)
return
}
lastIP := nbnet.GetLastIPFromNetwork(ipnet, 1).String()
if lastIP != tt.ip {
t.Errorf("wrong IP address, expected %s: got %s", tt.ip, lastIP)
}
}
}

View File

@@ -2,6 +2,7 @@ package dns
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error var err error
defer func() { defer func() {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
log.Tracef("%s has been stopped", u) logger.Tracef("%s has been stopped", u)
return return
default: default:
} }
@@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue continue
} }
if rm == nil || !rm.Response { if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
u.successCount.Add(1) u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil { if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure) m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil { if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
} }
} }
@@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -3,6 +3,7 @@ package dns
import ( import (
"context" "context"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -23,8 +24,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
hostsDNSHolder *hostsDNSHolder, hostsDNSHolder *hostsDNSHolder,
domain string, domain string,
@@ -83,3 +84,10 @@ func (u *upstreamResolver) isLocalResolver(upstream string) bool {
} }
return false return false
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -4,7 +4,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -19,8 +19,8 @@ type upstreamResolver struct {
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
_ string, _ string,
_ net.IP, _ netip.Addr,
_ *net.IPNet, _ netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -36,3 +36,10 @@ func newUpstreamResolver(
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) { func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream) return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
} }
func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
return &dns.Client{
Timeout: dialTimeout,
Net: "udp",
}, nil
}

View File

@@ -6,6 +6,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"syscall" "syscall"
"time" "time"
@@ -18,16 +19,16 @@ import (
type upstreamResolverIOS struct { type upstreamResolverIOS struct {
*upstreamResolverBase *upstreamResolverBase
lIP net.IP lIP netip.Addr
lNet *net.IPNet lNet netip.Prefix
interfaceName string interfaceName string
} }
func newUpstreamResolver( func newUpstreamResolver(
ctx context.Context, ctx context.Context,
interfaceName string, interfaceName string,
ip net.IP, ip netip.Addr,
net *net.IPNet, net netip.Prefix,
statusRecorder *peer.Status, statusRecorder *peer.Status,
_ *hostsDNSHolder, _ *hostsDNSHolder,
domain string, domain string,
@@ -58,8 +59,11 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
} }
client.DialTimeout = timeout client.DialTimeout = timeout
upstreamIP := net.ParseIP(upstreamHost) upstreamIP, err := netip.ParseAddr(upstreamHost)
if u.lNet.Contains(upstreamIP) || net.IP.IsPrivate(upstreamIP) { if err != nil {
log.Warnf("failed to parse upstream host %s: %s", upstreamHost, err)
}
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
log.Debugf("using private client to query upstream: %s", upstream) log.Debugf("using private client to query upstream: %s", upstream)
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout) client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
if err != nil { if err != nil {
@@ -73,7 +77,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface // GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
// This method is needed for iOS // This method is needed for iOS
func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) { func GetClientPrivate(ip netip.Addr, interfaceName string, dialTimeout time.Duration) (*dns.Client, error) {
index, err := getInterfaceIndex(interfaceName) index, err := getInterfaceIndex(interfaceName)
if err != nil { if err != nil {
log.Debugf("unable to get interface index for %s: %s", interfaceName, err) log.Debugf("unable to get interface index for %s: %s", interfaceName, err)
@@ -82,7 +86,7 @@ func GetClientPrivate(ip net.IP, interfaceName string, dialTimeout time.Duration
dialer := &net.Dialer{ dialer := &net.Dialer{
LocalAddr: &net.UDPAddr{ LocalAddr: &net.UDPAddr{
IP: ip, IP: ip.AsSlice(),
Port: 0, // Let the OS pick a free port Port: 0, // Let the OS pick a free port
}, },
Timeout: dialTimeout, Timeout: dialTimeout,

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"context" "context"
"net" "net/netip"
"strings" "strings"
"testing" "testing"
"time" "time"
@@ -58,7 +58,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO()) ctx, cancel := context.WithCancel(context.TODO())
resolver, _ := newUpstreamResolver(ctx, "", net.IP{}, &net.IPNet{}, nil, nil, ".") resolver, _ := newUpstreamResolver(ctx, "", netip.Addr{}, netip.Prefix{}, nil, nil, ".")
resolver.upstreamServers = testCase.InputServers resolver.upstreamServers = testCase.InputServers
resolver.upstreamTimeout = testCase.timeout resolver.upstreamTimeout = testCase.timeout
if testCase.cancelCTX { if testCase.cancelCTX {

View File

@@ -18,14 +18,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress string
ttl uint32 ttl uint32
@@ -38,16 +44,18 @@ type DNSForwarder struct {
mutex sync.RWMutex mutex sync.RWMutex
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewall.Manager firewall firewaller
resolver resolver
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
} }
} }
@@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server // UDP server
mux := dns.NewServeMux() mux := dns.NewServeMux()
f.mux = mux f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{ f.dnsServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
} }
// TCP server // TCP server
tcpMux := dns.NewServeMux() tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{ f.tcpServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "tcp", Net: "tcp",
@@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown) // return the first error we get (e.g. bind failure or shutdown)
return <-errCh return <-errCh
} }
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries f.fwdEntries = entries
return log.Debugf("Updated DNS forwarder with %d domains", len(entries))
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains)
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(w, query, resp, domain, err)
return nil return nil
} }
f.updateInternalState(domain, ips) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
return resp return resp
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) resp := f.handleDNSQuery(w, query)
if resp == nil { if resp == nil {
return return
@@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
var prefixes []netip.Prefix var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" { if mostSpecificResId != "" {
for _, ip := range ips { for _, ip := range ips {
var prefix netip.Prefix var prefix netip.Prefix
@@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,11 +1,21 @@
package dnsfwd package dnsfwd
import ( import (
"context"
"fmt"
"net/netip"
"strings"
"testing" "testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -13,7 +23,7 @@ import (
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId storedMappings map[string]route.ResID
queryDomain string queryDomain string
expectedResId route.ResID expectedResId route.ResID
}{ }{
@@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
{ {
name: "Wildcard pattern does not match different domain", name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"}, storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com", queryDomain: "foo.example.org",
expectedResId: "", expectedResId: "",
}, },
{ {
@@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
}) })
} }
} }
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -121,8 +121,8 @@ type EngineConfig struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }
@@ -359,6 +359,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("new wg interface: %w", err) return fmt.Errorf("new wg interface: %w", err)
} }
e.wgInterface = wgIface e.wgInterface = wgIface
e.statusRecorder.SetWgIface(wgIface)
// start flow manager right after interface creation // start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey() publicKey := e.config.WgPrivateKey.PublicKey()
@@ -380,7 +381,6 @@ func (e *Engine) Start() error {
return fmt.Errorf("run rosenpass manager: %w", err) return fmt.Errorf("run rosenpass manager: %w", err)
} }
} }
e.stateManager.Start() e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsServer, err := e.newDnsServer()
@@ -431,7 +431,8 @@ func (e *Engine) Start() error {
return fmt.Errorf("up wg interface: %w", err) return fmt.Errorf("up wg interface: %w", err)
} }
if e.firewall != nil { // if inbound conns are blocked there is no need to create the ACL manager
if e.firewall != nil && !e.config.BlockInbound {
e.acl = acl.NewDefaultManager(e.firewall) e.acl = acl.NewDefaultManager(e.firewall)
} }
@@ -487,12 +488,10 @@ func (e *Engine) createFirewall() error {
} }
func (e *Engine) initFirewall() error { func (e *Engine) initFirewall() error {
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close() e.close()
return fmt.Errorf("enable server router: %w", err) return fmt.Errorf("enable server router: %w", err)
} }
}
if e.config.BlockLANAccess { if e.config.BlockLANAccess {
e.blockLanAccess() e.blockLanAccess()
@@ -525,6 +524,11 @@ func (e *Engine) initFirewall() error {
} }
func (e *Engine) blockLanAccess() { func (e *Engine) blockLanAccess() {
if e.config.BlockInbound {
// no need to set up extra deny rules if inbound is already blocked in general
return
}
var merr *multierror.Error var merr *multierror.Error
// TODO: keep this updated // TODO: keep this updated
@@ -782,6 +786,9 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.DisableServerRoutes, e.config.DisableServerRoutes,
e.config.DisableDNS, e.config.DisableDNS,
e.config.DisableFirewall, e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
) )
if err := e.mgmClient.SyncMeta(info); err != nil { if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -796,11 +803,15 @@ func isNil(server nbssh.Server) bool {
} }
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Infof("SSH server is disabled because inbound connections are blocked")
return nil
}
if !e.config.ServerSSHAllowed { if !e.config.ServerSSHAllowed {
log.Warnf("running SSH server is not permitted") log.Info("SSH server is not enabled")
return nil return nil
} else { }
if sshConf.GetSshEnabled() { if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
@@ -844,8 +855,6 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
e.sshServer = nil e.sshServer = nil
} }
return nil return nil
}
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
@@ -899,6 +908,9 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableServerRoutes, e.config.DisableServerRoutes,
e.config.DisableDNS, e.config.DisableDNS,
e.config.DisableFirewall, e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
) )
// err = e.mgmClient.Sync(info, e.handleSync) // err = e.mgmClient.Sync(info, e.handleSync)
@@ -988,12 +1000,29 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
// apply routes first, route related actions might depend on routing being enabled // apply routes first, route related actions might depend on routing being enabled
routes := toRoutes(networkMap.GetRoutes()) routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
log.Errorf("failed to update clientRoutes, err: %v", err)
// lazy mgr needs to be aware of which routes are available before they are applied
if e.connMgr != nil {
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
} }
if e.acl != nil { if e.acl != nil {
@@ -1052,17 +1081,8 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store // must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(routes, forwardingRules, networkMap.GetRemotePeers()) excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, networkMap.GetRemotePeers())
e.connMgr.SetExcludeList(excludedLazyPeers) e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil {
protoDNSConfig = &mgmProto.DNSConfig{}
}
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
e.networkSerial = serial e.networkSerial = serial
@@ -1098,7 +1118,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
convertedRoute := &route.Route{ convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID), ID: route.ID(protoRoute.ID),
Network: prefix, Network: prefix.Masked(),
Domains: domain.FromPunycodeList(protoRoute.Domains), Domains: domain.FromPunycodeList(protoRoute.Domains),
NetID: route.NetID(protoRoute.NetID), NetID: route.NetID(protoRoute.NetID),
NetworkType: route.NetworkType(protoRoute.NetworkType), NetworkType: route.NetworkType(protoRoute.NetworkType),
@@ -1132,7 +1152,7 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
return entries return entries
} }
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config { func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
dnsUpdate := nbdns.Config{ dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(), ServiceEnable: protoDNSConfig.GetServiceEnable(),
CustomZones: make([]nbdns.CustomZone, 0), CustomZones: make([]nbdns.CustomZone, 0),
@@ -1447,6 +1467,7 @@ func (e *Engine) close() {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
} }
e.wgInterface = nil e.wgInterface = nil
e.statusRecorder.SetWgIface(nil)
} }
if !isNil(e.sshServer) { if !isNil(e.sshServer) {
@@ -1478,6 +1499,9 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
e.config.DisableServerRoutes, e.config.DisableServerRoutes,
e.config.DisableDNS, e.config.DisableDNS,
e.config.DisableFirewall, e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.LazyConnectionEnabled,
) )
netMap, err := e.mgmClient.GetNetworkMap(info) netMap, err := e.mgmClient.GetNetworkMap(info)
@@ -1503,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
} }
switch runtime.GOOS { switch runtime.GOOS {
@@ -1671,7 +1696,7 @@ func (e *Engine) RunHealthProbes() bool {
func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
return append( return append(
relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns),
relay.ProbeAll(e.ctx, relay.ProbeSTUN, turns)..., relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)...,
) )
} }
@@ -1784,9 +1809,9 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
} }
// GetWgAddr returns the wireguard address // GetWgAddr returns the wireguard address
func (e *Engine) GetWgAddr() net.IP { func (e *Engine) GetWgAddr() netip.Addr {
if e.wgInterface == nil { if e.wgInterface == nil {
return nil return netip.Addr{}
} }
return e.wgInterface.Address().IP return e.wgInterface.Address().IP
} }
@@ -1796,6 +1821,10 @@ func (e *Engine) updateDNSForwarder(
enabled bool, enabled bool,
fwdEntries []*dnsfwd.ForwarderEntry, fwdEntries []*dnsfwd.ForwarderEntry,
) { ) {
if e.config.DisableServerRoutes {
return
}
if !enabled { if !enabled {
if e.dnsForwardMgr == nil { if e.dnsForwardMgr == nil {
return return
@@ -1851,12 +1880,7 @@ func (e *Engine) Address() (netip.Addr, error) {
return netip.Addr{}, errors.New("wireguard interface not initialized") return netip.Addr{}, errors.New("wireguard interface not initialized")
} }
addr := e.wgInterface.Address() return e.wgInterface.Address().IP, nil
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return netip.Addr{}, errors.New("failed to convert address to netip.Addr")
}
return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
@@ -1927,16 +1951,8 @@ func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewal
return forwardingRules, nberrors.FormatErrorOrNil(merr) return forwardingRules, nberrors.FormatErrorOrNil(merr)
} }
func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) []string { func (e *Engine) toExcludedLazyPeers(rules []firewallManager.ForwardRule, peers []*mgmProto.RemotePeerConfig) map[string]bool {
excludedPeers := make([]string, 0) excludedPeers := make(map[string]bool)
for _, r := range routes {
if r.Peer == "" {
continue
}
log.Infof("exclude router peer from lazy connection: %s", r.Peer)
excludedPeers = append(excludedPeers, r.Peer)
}
for _, r := range rules { for _, r := range rules {
ip := r.TranslatedAddress ip := r.TranslatedAddress
for _, p := range peers { for _, p := range peers {
@@ -1945,7 +1961,7 @@ func (e *Engine) toExcludedLazyPeers(routes []*route.Route, rules []firewallMana
continue continue
} }
log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey()) log.Infof("exclude forwarder peer from lazy connection: %s", p.GetWgPubKey())
excludedPeers = append(excludedPeers, p.GetWgPubKey()) excludedPeers[p.GetWgPubKey()] = true
} }
} }
} }

View File

@@ -86,8 +86,8 @@ type MockWGIface struct {
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeerFunc func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeerFunc func(peerKey string) error RemovePeerFunc func(peerKey string) error
AddAllowedIPFunc func(peerKey string, allowedIP string) error AddAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP netip.Prefix) error
CloseFunc func() error CloseFunc func() error
SetFilterFunc func(filter device.PacketFilter) error SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() device.PacketFilter GetFilterFunc func() device.PacketFilter
@@ -99,6 +99,10 @@ type MockWGIface struct {
GetNetFunc func() *netstack.Net GetNetFunc func() *netstack.Net
} }
func (m *MockWGIface) FullStats() (*configurer.Stats, error) {
return nil, fmt.Errorf("not implemented")
}
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
return m.GetInterfaceGUIDStringFunc() return m.GetInterfaceGUIDStringFunc()
} }
@@ -143,11 +147,11 @@ func (m *MockWGIface) RemovePeer(peerKey string) error {
return m.RemovePeerFunc(peerKey) return m.RemovePeerFunc(peerKey)
} }
func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP string) error { func (m *MockWGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
return m.AddAllowedIPFunc(peerKey, allowedIP) return m.AddAllowedIPFunc(peerKey, allowedIP)
} }
func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP string) error { func (m *MockWGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
return m.RemoveAllowedIPFunc(peerKey, allowedIP) return m.RemoveAllowedIPFunc(peerKey, allowedIP)
} }
@@ -371,11 +375,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { UpdatePeerFunc: func(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
@@ -646,7 +647,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
inputErr error inputErr error
networkMap *mgmtProto.NetworkMap networkMap *mgmtProto.NetworkMap
expectedLen int expectedLen int
expectedRoutes []*route.Route expectedClientRoutes route.HAMap
expectedSerial uint64 expectedSerial uint64
}{ }{
{ {
@@ -675,7 +676,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}, },
}, },
expectedLen: 2, expectedLen: 2,
expectedRoutes: []*route.Route{ expectedClientRoutes: route.HAMap{
"n1|192.168.0.0/24": []*route.Route{
{ {
ID: "a", ID: "a",
Network: netip.MustParsePrefix("192.168.0.0/24"), Network: netip.MustParsePrefix("192.168.0.0/24"),
@@ -684,6 +686,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
NetworkType: 1, NetworkType: 1,
Masquerade: false, Masquerade: false,
}, },
},
"n2|192.168.1.0/24": []*route.Route{
{ {
ID: "b", ID: "b",
Network: netip.MustParsePrefix("192.168.1.0/24"), Network: netip.MustParsePrefix("192.168.1.0/24"),
@@ -693,6 +697,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
Masquerade: false, Masquerade: false,
}, },
}, },
},
expectedSerial: 1, expectedSerial: 1,
}, },
{ {
@@ -704,7 +709,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
Routes: nil, Routes: nil,
}, },
expectedLen: 0, expectedLen: 0,
expectedRoutes: []*route.Route{}, expectedClientRoutes: nil,
expectedSerial: 1, expectedSerial: 1,
}, },
{ {
@@ -717,7 +722,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
Routes: nil, Routes: nil,
}, },
expectedLen: 0, expectedLen: 0,
expectedRoutes: []*route.Route{}, expectedClientRoutes: nil,
expectedSerial: 1, expectedSerial: 1,
}, },
} }
@@ -762,15 +767,28 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
inputRoutes []*route.Route clientRoutes route.HAMap
}{} }{}
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { UpdateRoutesFunc: func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error {
input.inputSerial = updateSerial input.inputSerial = updateSerial
input.inputRoutes = newRoutes input.clientRoutes = clientRoutes
return testCase.inputErr return testCase.inputErr
}, },
ClassifyRoutesFunc: func(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) {
if len(newRoutes) == 0 {
return nil, nil
}
// Classify all routes as client routes (not matching our public key)
clientRoutes := make(route.HAMap)
for _, r := range newRoutes {
haID := r.GetHAUniqueID()
clientRoutes[haID] = append(clientRoutes[haID], r)
}
return nil, clientRoutes
},
} }
engine.routeManager = mockRouteManager engine.routeManager = mockRouteManager
@@ -788,8 +806,8 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
err = engine.updateNetworkMap(testCase.networkMap) err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match") assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputRoutes, testCase.expectedLen, "clientRoutes len should match") assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
assert.Equal(t, testCase.expectedRoutes, input.inputRoutes, "clientRoutes should match") assert.Equal(t, testCase.expectedClientRoutes, input.clientRoutes, "clientRoutes should match")
}) })
} }
} }
@@ -950,7 +968,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { UpdateRoutesFunc: func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error {
return nil return nil
}, },
} }

View File

@@ -28,8 +28,8 @@ type wgIfaceBase interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
Close() error Close() error
SetFilter(filter device.PacketFilter) error SetFilter(filter device.PacketFilter) error
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
@@ -37,4 +37,5 @@ type wgIfaceBase interface {
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetStats() (map[string]configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
GetNet() *netstack.Net GetNet() *netstack.Net
FullStats() (*configurer.Stats, error)
} }

View File

@@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() {
func (i *Monitor) ResetTimer() { func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold) i.timer.Reset(i.inactivityThreshold)
} }
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity" "github.com/netbirdio/netbird/client/internal/lazyconn/activity"
@@ -13,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/dispatcher" "github.com/netbirdio/netbird/client/internal/peer/dispatcher"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/route"
) )
const ( const (
@@ -37,7 +39,9 @@ type Config struct {
// - Managing inactivity monitors for lazy connections (based on peer disconnection events) // - Managing inactivity monitors for lazy connections (based on peer disconnection events)
// - Maintaining a list of excluded peers that should always have permanent connections // - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling // - Handling connection establishment based on peer signaling
// - Managing route HA groups and activating all peers in a group when one peer is activated
type Manager struct { type Manager struct {
engineCtx context.Context
peerStore *peerstore.Store peerStore *peerstore.Store
connStateDispatcher *dispatcher.ConnectionDispatcher connStateDispatcher *dispatcher.ConnectionDispatcher
inactivityThreshold time.Duration inactivityThreshold time.Duration
@@ -51,13 +55,20 @@ type Manager struct {
activityManager *activity.Manager activityManager *activity.Manager
inactivityMonitors map[peerid.ConnID]*inactivity.Monitor inactivityMonitors map[peerid.ConnID]*inactivity.Monitor
cancel context.CancelFunc // Route HA group management
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex
onInactive chan peerid.ConnID onInactive chan peerid.ConnID
} }
func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager { // NewManager creates a new lazy connection manager
// engineCtx is the context for creating peer Connection
func NewManager(config Config, engineCtx context.Context, peerStore *peerstore.Store, wgIface lazyconn.WGIface, connStateDispatcher *dispatcher.ConnectionDispatcher) *Manager {
log.Infof("setup lazy connection service") log.Infof("setup lazy connection service")
m := &Manager{ m := &Manager{
engineCtx: engineCtx,
peerStore: peerStore, peerStore: peerStore,
connStateDispatcher: connStateDispatcher, connStateDispatcher: connStateDispatcher,
inactivityThreshold: inactivity.DefaultInactivityThreshold, inactivityThreshold: inactivity.DefaultInactivityThreshold,
@@ -66,6 +77,8 @@ func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIf
excludes: make(map[string]lazyconn.PeerConfig), excludes: make(map[string]lazyconn.PeerConfig),
activityManager: activity.NewManager(wgIface), activityManager: activity.NewManager(wgIface),
inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor), inactivityMonitors: make(map[peerid.ConnID]*inactivity.Monitor),
peerToHAGroups: make(map[string][]route.HAUniqueID),
haGroupToPeers: make(map[route.HAUniqueID][]string),
onInactive: make(chan peerid.ConnID), onInactive: make(chan peerid.ConnID),
} }
@@ -87,11 +100,45 @@ func NewManager(config Config, peerStore *peerstore.Store, wgIface lazyconn.WGIf
return m return m
} }
// UpdateRouteHAMap updates the HA group mappings for routes
// This should be called when route configuration changes
func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
m.routesMu.Lock()
defer m.routesMu.Unlock()
maps.Clear(m.peerToHAGroups)
maps.Clear(m.haGroupToPeers)
for haUniqueID, routes := range haMap {
var peers []string
peerSet := make(map[string]bool)
for _, r := range routes {
if !peerSet[r.Peer] {
peerSet[r.Peer] = true
peers = append(peers, r.Peer)
}
}
if len(peers) <= 1 {
continue
}
m.haGroupToPeers[haUniqueID] = peers
for _, peerID := range peers {
m.peerToHAGroups[peerID] = append(m.peerToHAGroups[peerID], haUniqueID)
}
}
log.Debugf("updated route HA mappings: %d HA groups, %d peers with routes",
len(m.haGroupToPeers), len(m.peerToHAGroups))
}
// Start starts the manager and listens for peer activity and inactivity events // Start starts the manager and listens for peer activity and inactivity events
func (m *Manager) Start(ctx context.Context) { func (m *Manager) Start(ctx context.Context) {
defer m.close() defer m.close()
ctx, m.cancel = context.WithCancel(ctx)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -99,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) {
case peerConnID := <-m.activityManager.OnActivityChan: case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID) m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive: case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID) m.onPeerInactivityTimedOut(ctx, peerConnID)
} }
} }
} }
@@ -150,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
return added return added
} }
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -178,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
peerCfg: &peerCfg, peerCfg: &peerCfg,
expectedWatcher: watcherActivity, expectedWatcher: watcherActivity,
} }
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(ctx, peerCfg)
}
return false, nil return false, nil
} }
@@ -209,25 +263,47 @@ func (m *Manager) RemovePeer(peerID string) {
} }
// ActivatePeer activates a peer connection when a signal message is received // ActivatePeer activates a peer connection when a signal message is received
// Also activates all peers in the same HA groups as this peer
func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) { func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
return false
}
if !m.activateSinglePeer(ctx, cfg, mp) {
return false
}
m.activateHAGroupPeers(ctx, peerID)
return true
}
// getPeerForActivation checks if a peer can be activated and returns the necessary structs
// Returns nil values if the peer should be skipped
func (m *Manager) getPeerForActivation(peerID string) (*lazyconn.PeerConfig, *managedPeer) {
cfg, ok := m.managedPeers[peerID] cfg, ok := m.managedPeers[peerID]
if !ok { if !ok {
return false return nil, nil
} }
mp, ok := m.managedPeersByConnID[cfg.PeerConnID] mp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok { if !ok {
return false return nil, nil
} }
// signal messages coming continuously after success activation, with this avoid the multiple activation // signal messages coming continuously after success activation, with this avoid the multiple activation
if mp.expectedWatcher == watcherInactivity { if mp.expectedWatcher == watcherInactivity {
return false return nil, nil
} }
return cfg, mp
}
// activateSinglePeer activates a single peer (internal method)
func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConfig, mp *managedPeer) bool {
mp.expectedWatcher = watcherInactivity mp.expectedWatcher = watcherInactivity
m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID) m.activityManager.RemovePeer(cfg.Log, cfg.PeerConnID)
@@ -238,12 +314,100 @@ func (m *Manager) ActivatePeer(ctx context.Context, peerID string) (found bool)
return false return false
} }
mp.peerCfg.Log.Infof("starting inactivity monitor") cfg.Log.Infof("starting inactivity monitor")
go im.Start(ctx, m.onInactive) go im.Start(ctx, m.onInactive)
return true return true
} }
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
var peersToActivate []string
m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID]
if len(haGroups) == 0 {
m.routesMu.RUnlock()
log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, peerID := range peers {
if peerID != triggerPeerID {
peersToActivate = append(peersToActivate, peerID)
}
}
}
m.routesMu.RUnlock()
activatedCount := 0
for _, peerID := range peersToActivate {
cfg, mp := m.getPeerForActivation(peerID)
if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++
cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
}
if activatedCount > 0 {
log.Infof("activated %d additional peers in HA groups for peer %s (groups: %v)",
activatedCount, triggerPeerID, haGroups)
}
}
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed") peerCfg.Log.Warnf("peer already managed")
@@ -287,8 +451,6 @@ func (m *Manager) close() {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
m.cancel()
m.connStateDispatcher.RemoveListener(m.connStateListener) m.connStateDispatcher.RemoveListener(m.connStateListener)
m.activityManager.Close() m.activityManager.Close()
for _, iw := range m.inactivityMonitors { for _, iw := range m.inactivityMonitors {
@@ -297,9 +459,58 @@ func (m *Manager) close() {
m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor) m.inactivityMonitors = make(map[peerid.ConnID]*inactivity.Monitor)
m.managedPeers = make(map[string]*lazyconn.PeerConfig) m.managedPeers = make(map[string]*lazyconn.PeerConfig)
m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer) m.managedPeersByConnID = make(map[peerid.ConnID]*managedPeer)
// Clear route mappings
m.routesMu.Lock()
m.peerToHAGroups = make(map[string][]route.HAUniqueID)
m.haGroupToPeers = make(map[route.HAUniqueID][]string)
m.routesMu.Unlock()
log.Infof("lazy connection manager closed") log.Infof("lazy connection manager closed")
} }
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
}
}
return false
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -317,15 +528,16 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
mp.peerCfg.Log.Infof("detected peer activity") mp.peerCfg.Log.Infof("detected peer activity")
mp.expectedWatcher = watcherInactivity if !m.activateSinglePeer(ctx, mp.peerCfg, mp) {
return
mp.peerCfg.Log.Infof("starting inactivity monitor")
go m.inactivityMonitors[peerConnID].Start(ctx, m.onInactive)
m.peerStore.PeerConnOpen(ctx, mp.peerCfg.PublicKey)
} }
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) { m.activateHAGroupPeers(ctx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
}
func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -340,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
return return
} }
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
}
return
}
mp.peerCfg.Log.Infof("connection timed out") mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized // this is blocking operation, potentially can be optimized
@@ -373,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok { if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer") mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return return
} }

View File

@@ -116,6 +116,9 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.DisableServerRoutes, config.DisableServerRoutes,
config.DisableDNS, config.DisableDNS,
config.DisableFirewall, config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
) )
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) _, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err return serverKey, err
@@ -139,6 +142,9 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.DisableServerRoutes, config.DisableServerRoutes,
config.DisableDNS, config.DisableDNS,
config.DisableFirewall, config.DisableFirewall,
config.BlockLANAccess,
config.BlockInbound,
config.LazyConnectionEnabled,
) )
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {

View File

@@ -204,7 +204,7 @@ func (c *ConnTrack) handleEvent(event nfct.Event) {
eventStr = "Ended" eventStr = "Ended"
} }
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort) log.Tracef("%s %s %s connection: %s:%d %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
c.flowLogger.StoreEvent(nftypes.EventFields{ c.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
@@ -232,7 +232,7 @@ func (c *ConnTrack) relevantFlow(mark uint32, srcIP, dstIP netip.Addr) bool {
// fallback if mark rules are not in place // fallback if mark rules are not in place
wgnet := c.iface.Address().Network wgnet := c.iface.Address().Network
return wgnet.Contains(srcIP.AsSlice()) || wgnet.Contains(dstIP.AsSlice()) return wgnet.Contains(srcIP) || wgnet.Contains(dstIP)
} }
// mapRxPackets maps packet counts to RX based on flow direction // mapRxPackets maps packet counts to RX based on flow direction
@@ -293,17 +293,15 @@ func (c *ConnTrack) inferDirection(mark uint32, srcIP, dstIP netip.Addr) nftypes
// fallback if marks are not set // fallback if marks are not set
wgaddr := c.iface.Address().IP wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch { switch {
case wgaddr.Equal(src): case wgaddr == srcIP:
return nftypes.Egress return nftypes.Egress
case wgaddr.Equal(dst): case wgaddr == dstIP:
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(src): case wgnetwork.Contains(srcIP):
// netbird network -> resource network // netbird network -> resource network
return nftypes.Ingress return nftypes.Ingress
case wgnetwork.Contains(dst): case wgnetwork.Contains(dstIP):
// resource network -> netbird network // resource network -> netbird network
return nftypes.Egress return nftypes.Egress
} }

View File

@@ -2,7 +2,7 @@ package logger
import ( import (
"context" "context"
"net" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -23,17 +23,16 @@ type Logger struct {
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancel context.CancelFunc cancel context.CancelFunc
statusRecorder *peer.Status statusRecorder *peer.Status
wgIfaceIPNet net.IPNet wgIfaceNet netip.Prefix
dnsCollection atomic.Bool dnsCollection atomic.Bool
exitNodeCollection atomic.Bool exitNodeCollection atomic.Bool
Store types.Store Store types.Store
} }
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger { func New(statusRecorder *peer.Status, wgIfaceIPNet netip.Prefix) *Logger {
return &Logger{ return &Logger{
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
wgIfaceIPNet: wgIfaceIPNet, wgIfaceNet: wgIfaceIPNet,
Store: store.NewMemoryStore(), Store: store.NewMemoryStore(),
} }
} }
@@ -89,11 +88,11 @@ func (l *Logger) startReceiver() {
var isSrcExitNode bool var isSrcExitNode bool
var isDestExitNode bool var isDestExitNode bool
if !l.wgIfaceIPNet.Contains(net.IP(event.SourceIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.SourceIP) {
event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP) event.SourceResourceID, isSrcExitNode = l.statusRecorder.CheckRoutes(event.SourceIP)
} }
if !l.wgIfaceIPNet.Contains(net.IP(event.DestIP.AsSlice())) { if !l.wgIfaceNet.Contains(event.DestIP) {
event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP) event.DestResourceID, isDestExitNode = l.statusRecorder.CheckRoutes(event.DestIP)
} }

Some files were not shown because too many files have changed in this diff Show More