Compare commits

...

79 Commits

Author SHA1 Message Date
Viktor Liu
e20be2397c [client] Add missing peer ACL flush (#3247) 2025-01-28 23:25:22 +01:00
Maycon Santos
46766e7e24 [misc] Update sign pipeline version (#3246) 2025-01-28 22:48:19 +01:00
Viktor Liu
a7ddb8f1f8 [client] Replace engine probes with direct calls (#3195) 2025-01-28 12:25:45 +01:00
Pascal Fischer
7335c82553 [management] copy destination and source resource on policyRUle copy (#3235) 2025-01-28 07:05:21 +01:00
Viktor Liu
a32ec97911 [client] Use dynamic dns route resolution on iOS (#3243) 2025-01-27 18:13:10 +01:00
Viktor Liu
5c05131a94 [client] Support port ranges in peer ACLs (#3232) 2025-01-27 13:51:57 +01:00
Pascal Fischer
b6abd4b4da [management/signal/relay] add metrics descriptions (#3233) 2025-01-24 14:17:30 +01:00
Pascal Fischer
2605948e01 [management] use account request buffer on sync (#3229) 2025-01-24 12:04:50 +01:00
Viktor Liu
eb2ac039c7 [client] Mark redirected traffic early to match input filters on pre-DNAT ports (#3205) 2025-01-23 18:00:51 +01:00
Viktor Liu
790a9ed7df [client] Match more specific dns handler first (#3226) 2025-01-23 18:00:05 +01:00
Viktor Liu
2e61ce006d [client] Back up corrupted state files and present them in the debug bundle (#3227) 2025-01-23 17:59:44 +01:00
Viktor Liu
3cc485759e [client] Use correct stdout/stderr log paths for debug bundle on macOS (#3231) 2025-01-23 17:59:22 +01:00
Viktor Liu
aafa9c67fc [client] Fix freebsd default routes (#3230) 2025-01-23 16:57:11 +01:00
Pascal Fischer
69f48db0a3 [management] disable prepareStmt for sqlite (#3228) 2025-01-22 19:53:20 +01:00
Pascal Fischer
8c965434ae [management] remove peer from group on delete (#3223) 2025-01-22 19:33:20 +01:00
Eddie Garcia
78da6b42ad [misc] Fix typo in test output (#3216)
Fix a typo in test output
2025-01-22 18:57:54 +01:00
Bethuel Mmbaga
1ad2cb5582 [management] Refactor peers to use store methods (#2893) 2025-01-20 18:41:46 +01:00
Viktor Liu
c619bf5b0c [client] Allow freebsd to build netbird-ui (#3212) 2025-01-20 11:02:09 +01:00
Maycon Santos
9f4db0a953 [client] Close ice agent only if not nil (#3210) 2025-01-18 00:18:59 +01:00
Pascal Fischer
3e836db1d1 [management] add duration logs to Sync (#3203) 2025-01-17 12:26:44 +01:00
Bethuel Mmbaga
c01874e9ce [management] Fix network migration issue in postgres (#3198)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-01-17 14:00:46 +03:00
Viktor Liu
1b2517ea20 [relay] Don't start relay quic listener on invalid TLS config (#3202) 2025-01-17 11:39:08 +01:00
Viktor Liu
3e9f0d57ac [client] Fix windows info out of bounds panic (#3196) 2025-01-16 22:19:32 +01:00
Zoltan Papp
481bbe8513 [relay] Set InitialPacketSize to the maximum allowable value (#3188)
Fixes an issue on macOS where the server throws errors with default settings:
failed to write transport message to: DATAGRAM frame too large.

Further investigation is required to optimize MTU-related values.
2025-01-16 16:19:07 +01:00
Viktor Liu
bc7b2c6ba3 [client] Report client system flags to management server on login (#3187) 2025-01-16 13:58:00 +01:00
Pascal Fischer
c6f7a299a9 [management] fix groups delete and resource create and update error response (#3189) 2025-01-16 13:39:15 +01:00
Viktor Liu
992a6c79b4 [client] Flush macOS DNS cache after changes (#3185) 2025-01-15 23:26:31 +01:00
Viktor Liu
78795a4a73 [client] Add block lan access flag for routers (#3171) 2025-01-15 17:39:47 +01:00
Viktor Liu
5a82477d48 [client] Remove outbound chains (#3157) 2025-01-15 16:57:41 +01:00
Zoltan Papp
1ffa519387 [client,relay] Add QUIC support (#2962) 2025-01-15 16:28:19 +01:00
Edouard Vanbelle
e4a25b6a60 [client-android] add serial, product model, product manufacturer (#2958)
Signed-off-by: Edouard Vanbelle <edouard.vanbelle@shadow.tech>
2025-01-15 16:02:16 +01:00
Zoltan Papp
6a6b527f24 [relay] Code cleaning (#3074)
- Keep message byte processing in message.go file
- Add new unit tests
2025-01-15 16:01:08 +01:00
Viktor Liu
b34887a920 [client] Fix a panic on shutdown if dns host manager failed to initialize (#3182) 2025-01-15 13:14:46 +01:00
Viktor Liu
b9efda3ce8 [client] Disable DNS host manager for netstack mode (#3183) 2025-01-15 13:14:13 +01:00
James Hilliard
516de93627 [client] Fix gvisor.dev/gvisor commit (#3179)
Commit b8a429915ff1 was replaced with db3d49b921f9 in gvisor project.
2025-01-15 10:54:51 +01:00
Viktor Liu
15f0a665f8 [client] Allow ssh server on freebsd (#3170)
* Enable ssh server on freebsd

* Fix listening in netstack mode

* Fix panic if login cmd fails

* Tidy up go mod
2025-01-14 12:43:13 +01:00
Viktor Liu
9b5b632ff9 [client] Support non-openresolv for DNS on Linux (#3176) 2025-01-14 10:39:37 +01:00
adasauce
0c28099712 [management] enable optional zitadel configuration of a PAT (#3159)
* [management] enable optional zitadel configuration of a PAT for service user via the ExtraConfig fields

* [management] validate both PAT and JWT configurations for zitadel
2025-01-14 12:38:08 +03:00
Krzysztof Nazarewski (kdn)
522dd44bfa [client] make /var/lib/netbird paths configurable (#3084)
- NB_STATE_DIR
- NB_UNCLEAN_SHUTDOWN_RESOLV_FILE
- NB_DNS_STATE_FILE
2025-01-13 10:15:01 +01:00
Maycon Santos
8154069e77 [misc] Skip docker step when fork PR (#3175) 2025-01-13 10:11:54 +01:00
Viktor Liu
e161a92898 [client] Update fyne dependency (#3155) 2025-01-12 16:29:25 +01:00
Bethuel Mmbaga
3fce8485bb Enabled new network resource and router by default (#3174)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-01-11 20:09:29 +01:00
Maycon Santos
1cc88a2190 [management] adjust benchmark (#3168) 2025-01-11 14:08:13 +01:00
Bethuel Mmbaga
168ea9560e [Management] Send peer network map when SSH status is toggled (#3172) 2025-01-11 13:19:30 +01:00
Viktor Liu
f48e33b395 [client] Don't fail on v6 ops when disabled via kernel params (#3165) 2025-01-10 18:16:21 +01:00
Simon Smith
f1ed8599fc [misc] add missing relay to docker-compose.yml.tmpl.traefik (#3163) 2025-01-10 18:16:11 +01:00
Viktor Liu
93f3e1b14b [client] Prevent local routes in status from being overridden by updates (#3166) 2025-01-10 11:02:05 +01:00
Maycon Santos
649bfb236b [management] Send relay credentials with turn updates (#3164)
send relay credentials when sending turn credentials update to avoid removing servers
from clients
2025-01-10 09:44:02 +01:00
Bethuel Mmbaga
409003b4f9 [management] Add support for disabling resources and routing peers in networks (#3154)
* sync openapi changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add option to disable network resource(s)

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add network resource enabled state from api

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add option to disable network router(s)

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* migrate old network resources and routers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-01-08 19:35:57 +03:00
Simon Smith
9e6e34b42d [misc] Upgrade go to 1.23 inn devcontainer (#3160) 2025-01-08 11:48:10 +01:00
Viktor Liu
d9905d1a57 [client] Add disable system flags (#3153) 2025-01-07 20:38:18 +01:00
Joakim Nohlgård
2bd68efc08 [relay] Handle IPv6 addresses in X-Real-IP header on relay service (#3085) 2025-01-06 17:31:35 +01:00
Viktor Liu
6848e1e128 [client] Add rootless container and fix client routes in netstack mode (#3150) 2025-01-06 14:16:31 +01:00
Viktor Liu
668aead4c8 [misc] remove outdated readme header (#3151) 2025-01-06 14:12:28 +01:00
Viktor Liu
f08605a7f1 [client] Enable network map persistence by default (#3152) 2025-01-06 14:11:43 +01:00
Bethuel Mmbaga
02a3feddb8 [management] Add MySQL Support (#3108)
* Add mysql store support
* Add support to disable activity events recording
2025-01-06 13:38:30 +01:00
Pascal Fischer
d9487a5749 [misc] separate integration and benchmark test workflows (#3147) 2025-01-03 15:48:31 +01:00
Pascal Fischer
cfa6d09c5e [management] add peers benchmark (#3143) 2025-01-03 15:28:15 +01:00
Pascal Fischer
a01253c3c8 [management] add users benchmark (#3141) 2025-01-03 15:24:30 +01:00
Pascal Fischer
bc013e4888 [management] exclude self from network map if self is routing peer (#3142) 2025-01-02 18:46:28 +01:00
Pascal Fischer
782e3f8853 [management] Add integration test for the setup-keys API endpoints (#2936) 2025-01-02 13:51:01 +01:00
Maycon Santos
03fd656344 [management] Fix policy tests (#3135)
- Add firewall rule isEqual method
- Fix tests
2024-12-31 18:45:40 +01:00
Pascal Fischer
18b049cd24 [management] remove sorting from network map generation (#3126) 2024-12-31 18:10:40 +01:00
Bethuel Mmbaga
2bdb4cb44a [management] Preserve jwt groups when accessing API with PAT (#3128)
* Skip JWT group sync for token-based authentication

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-31 18:59:37 +03:00
Viktor Liu
abbdf20f65 [client] Allow inbound rosenpass port (#3109) 2024-12-31 14:08:48 +01:00
Viktor Liu
43ef64cf67 [client] Ignore case when matching domains in handler chain (#3133) 2024-12-31 14:07:21 +01:00
Pascal Fischer
18316be09a [management] add selfhosted metrics for networks (#3118) 2024-12-30 12:53:51 +01:00
Maycon Santos
1a623943c8 [management] Fix networks net map generation with posture checks (#3124) 2024-12-30 12:40:24 +01:00
Pascal Fischer
fbce8bb511 [management] remove ids from policy creation api (#2997) 2024-12-27 14:13:36 +01:00
Bethuel Mmbaga
445b626dc8 [management] Add missing group usage checks for network resources and routes access control (#3117)
* Prevent deletion of groups linked to routes access control groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Prevent deletion of groups linked to network resource

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-27 14:39:34 +03:00
Viktor Liu
b3c87cb5d1 [client] Fix inbound tracking in userspace firewall (#3111)
* Don't create state for inbound SYN

* Allow final ack in some cases

* Relax state machine test a little
2024-12-26 00:51:27 +01:00
Viktor Liu
0dbaddc7be [client] Don't fail debug if log file is console (#3103) 2024-12-24 15:05:23 +01:00
Viktor Liu
ad9f044aad [client] Add stateful userspace firewall and remove egress filters (#3093)
- Add stateful firewall functionality for UDP/TCP/ICMP in userspace firewalll
- Removes all egress drop rules/filters, still needs refactoring so we don't add output rules to any chains/filters.
- on Linux, if the OUTPUT policy is DROP  then we don't do anything about it (no extra allow rules). This is up to the user, if they don't want anything leaving their machine they'll have to manage these rules explicitly.
2024-12-23 18:22:17 +01:00
Viktor Liu
05930ee6b1 [client] Add firewall rules to the debug bundle (#3089)
Adds the following to the debug bundle:
- iptables: `iptables-save`, `iptables -v -n -L`
- nftables: `nft list ruleset` or if not available formatted output from netlink (WIP)
2024-12-23 15:57:15 +01:00
Pascal Fischer
e670068cab [management] Run test sequential (#3101) 2024-12-23 14:37:09 +01:00
Viktor Liu
b48cf1bf65 [client] Reduce DNS handler chain lock contention (#3099) 2024-12-21 15:56:52 +01:00
Bethuel Mmbaga
7ee7ada273 [management] Fix duplicate resource routes when routing peer is part of the source group (#3095)
* Remove duplicate resource routes when routing peer is part of the source group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-12-20 21:10:53 +03:00
Zoltan Papp
82b4e58ad0 Do not start DNS forwarder on client side (#3094) 2024-12-20 16:20:50 +01:00
Viktor Liu
ddc365f7a0 [client, management] Add new network concept (#3047)
---------

Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
2024-12-20 11:30:28 +01:00
301 changed files with 27294 additions and 7870 deletions

View File

@@ -1,4 +1,4 @@
FROM golang:1.21-bullseye
FROM golang:1.23-bullseye
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
&& apt-get -y install --no-install-recommends\

View File

@@ -7,7 +7,7 @@
"features": {
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
"ghcr.io/devcontainers/features/go:1": {
"version": "1.21"
"version": "1.23"
}
},
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",

View File

@@ -44,4 +44,5 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)

View File

@@ -24,7 +24,7 @@ jobs:
copyback: false
release: "14.1"
prepare: |
pkg install -y go
pkg install -y go pkgconf xorg
# -x - to print all executed commands
# -e - to faile on first error
@@ -33,7 +33,7 @@ jobs:
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...

View File

@@ -13,7 +13,7 @@ concurrency:
jobs:
build-cache:
runs-on: ubuntu-22.04
steps:
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -134,9 +134,189 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management)
test_management:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags devcert -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./...
api_benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Login to Docker hub
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: download mysql image
if: matrix.store == 'mysql'
run: docker pull mlsmaycon/warmed-mysql:8
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=benchmark ./... | grep /management)
api_integration_test:
needs: [ build-cache ]
strategy:
fail-fast: false
@@ -183,56 +363,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management)
benchmark:
needs: [ build-cache ]
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./...
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 30m $(go list -tags=integration ./... | grep /management)
test_client_on_docker:
needs: [ build-cache ]

View File

@@ -65,7 +65,7 @@ jobs:
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
- name: test
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
- name: test output
if: ${{ always() }}
run: Get-Content test-out.txt

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum
only_warn: 1
golangci:
@@ -46,7 +46,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v4
with:
version: latest
args: --timeout=12m
args: --timeout=12m --out-format colored-line-number

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.0.17"
SIGN_PIPE_VER: "v0.0.18"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -20,7 +20,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
store: [ 'sqlite', 'postgres' ]
store: [ 'sqlite', 'postgres', 'mysql' ]
services:
postgres:
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
@@ -34,6 +34,19 @@ jobs:
--health-timeout 5s
ports:
- 5432:5432
mysql:
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
env:
MYSQL_USER: netbird
MYSQL_PASSWORD: mysql
MYSQL_ROOT_PASSWORD: mysqlroot
MYSQL_DATABASE: netbird
options: >-
--health-cmd "mysqladmin ping --silent"
--health-interval 10s
--health-timeout 5s
ports:
- 3306:3306
steps:
- name: Set Database Connection String
run: |
@@ -42,6 +55,11 @@ jobs:
else
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
fi
if [ "${{ matrix.store }}" == "mysql" ]; then
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
else
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
fi
- name: Install jq
run: sudo apt-get install -y jq
@@ -84,6 +102,7 @@ jobs:
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
- name: check values
@@ -112,6 +131,7 @@ jobs:
CI_NETBIRD_SIGNAL_PORT: 12345
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
@@ -149,6 +169,7 @@ jobs:
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml

View File

@@ -179,6 +179,51 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
ids:
@@ -377,6 +422,18 @@ docker_manifests:
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8

View File

@@ -1,10 +1,3 @@
<p align="center">
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
Learn more
</a>
</p>
<br/>
<div align="center">
<p align="center">
<img width="234" src="docs/media/logo-full.png"/>

View File

@@ -0,0 +1,16 @@
FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird
WORKDIR /var/lib/netbird
USER netbird:netbird
ENV NB_FOREGROUND_MODE=true
ENV NB_USE_NETSTACK_MODE=true
ENV NB_CONFIG=config.json
ENV NB_DAEMON_ADDR=unix://netbird.sock
ENV NB_DISABLE_DNS=true
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]

View File

@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
if err != nil {

View File

@@ -21,6 +21,8 @@ type Anonymizer struct {
currentAnonIPv6 netip.Addr
startAnonIPv4 netip.Addr
startAnonIPv6 netip.Addr
domainKeyRegex *regexp.Regexp
}
func DefaultAddresses() (netip.Addr, netip.Addr) {
@@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
currentAnonIPv6: startIPv6,
startAnonIPv4: startIPv4,
startAnonIPv6: startIPv6,
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
}
}
@@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
}
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
domainPattern := `dns\.Question{Name:"([^"]+)",`
domainRegex := regexp.MustCompile(domainPattern)
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.Split(match, `"`)
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
parts := strings.SplitN(match, "=", 2)
if len(parts) >= 2 {
domain := parts[1]
if strings.HasSuffix(domain, anonTLD) {
return match
}
randomDomain := generateRandomString(10) + anonTLD
return strings.Replace(match, domain, randomDomain, 1)
return "domain=" + a.AnonymizeDomain(domain)
}
return match
})

View File

@@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
func TestAnonymizeDNSLogLine(t *testing.T) {
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
tests := []struct {
name string
input string
original string
expect string
}{
{
name: "Basic domain with trailing content",
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
original: "example.com",
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
},
{
name: "Domain with trailing dot",
input: "domain=example.com. processing request with status=pending",
original: "example.com",
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
},
{
name: "Multiple domains in log",
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
},
{
name: "Already anonymized domain",
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
original: "", // nothing should be anonymized
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
},
{
name: "Subdomain with trailing dot",
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
original: "example.com",
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
},
{
name: "Handler chain pattern log",
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
original: "example.com",
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
},
}
result := anonymizer.AnonymizeDNSLogLine(testLog)
require.NotEqual(t, testLog, result)
assert.NotContains(t, result, "example.com")
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := anonymizer.AnonymizeDNSLogLine(tc.input)
if tc.original != "" {
assert.NotContains(t, result, tc.original)
}
assert.Regexp(t, tc.expect, result)
})
}
}
func TestAnonymizeDomain(t *testing.T) {

173
client/cmd/networks.go Normal file
View File

@@ -0,0 +1,173 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var networksCMD = &cobra.Command{
Use: "networks",
Aliases: []string{"routes"},
Short: "Manage networks",
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List networks",
Example: " netbird networks list",
Long: "List all available network routes.",
RunE: networksList,
}
var routesSelectCmd = &cobra.Command{
Use: "select network...|all",
Short: "Select network",
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: networksSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect network...|all",
Short: "Deselect networks",
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: networksDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
}
func networksList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No networks available.")
return nil
}
printNetworks(cmd, resp)
return nil
}
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
cmd.Println("Available Networks:")
for _, route := range resp.Routes {
printNetwork(cmd, route)
}
}
func printNetwork(cmd *cobra.Command, route *proto.Network) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Network) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for resolvedDomain, ipList := range resolvedIPs {
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
}
}
func networksSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks selected successfully.")
return nil
}
func networksDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectNetworksRequest{
NetworkIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
}
cmd.Println("Networks deselected successfully.")
return nil
}

View File

@@ -38,6 +38,7 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access"
)
var (
@@ -73,6 +74,7 @@ var (
anonymizeFlag bool
debugSystemInfoFlag bool
dnsRouteInterval time.Duration
blockLANAccess bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -142,14 +144,14 @@ func init() {
rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(routesCmd)
rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
routesCmd.AddCommand(routesListCmd)
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd)

View File

@@ -1,174 +0,0 @@
package cmd
import (
"fmt"
"strings"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var appendFlag bool
var routesCmd = &cobra.Command{
Use: "routes",
Short: "Manage network routes",
Long: `Commands to list, select, or deselect network routes.`,
}
var routesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List routes",
Example: " netbird routes list",
Long: "List all available network routes.",
RunE: routesList,
}
var routesSelectCmd = &cobra.Command{
Use: "select route...|all",
Short: "Select routes",
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
Args: cobra.MinimumNArgs(1),
RunE: routesSelect,
}
var routesDeselectCmd = &cobra.Command{
Use: "deselect route...|all",
Short: "Deselect routes",
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
Args: cobra.MinimumNArgs(1),
RunE: routesDeselect,
}
func init() {
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
}
func routesList(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
if err != nil {
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
}
if len(resp.Routes) == 0 {
cmd.Println("No routes available.")
return nil
}
printRoutes(cmd, resp)
return nil
}
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
cmd.Println("Available Routes:")
for _, route := range resp.Routes {
printRoute(cmd, route)
}
}
func printRoute(cmd *cobra.Command, route *proto.Route) {
selectedStatus := getSelectedStatus(route)
domains := route.GetDomains()
if len(domains) > 0 {
printDomainRoute(cmd, route, domains, selectedStatus)
} else {
printNetworkRoute(cmd, route, selectedStatus)
}
}
func getSelectedStatus(route *proto.Route) string {
if route.GetSelected() {
return "Selected"
}
return "Not Selected"
}
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
resolvedIPs := route.GetResolvedIPs()
if len(resolvedIPs) > 0 {
printResolvedIPs(cmd, domains, resolvedIPs)
} else {
cmd.Printf(" Resolved IPs: -\n")
}
}
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
}
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
cmd.Printf(" Resolved IPs:\n")
for _, domain := range domains {
if ipList, exists := resolvedIPs[domain]; exists {
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
}
}
}
func routesSelect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
} else if appendFlag {
req.Append = true
}
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes selected successfully.")
return nil
}
func routesDeselect(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
req := &proto.SelectRoutesRequest{
RouteIDs: args,
}
if len(args) == 1 && args[0] == "all" {
req.All = true
}
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
}
cmd.Println("Routes deselected successfully.")
return nil
}

View File

@@ -9,7 +9,6 @@ import (
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/client/internal"
@@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
log.Debug(err)
cmd.Printf("Error: %v\n", err)
os.Exit(1)
}
cancel()

View File

@@ -40,6 +40,7 @@ type peerStateDetailOutput struct {
Latency time.Duration `json:"latency" yaml:"latency"`
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
}
type peersStateOutput struct {
@@ -98,6 +99,7 @@ type statusOutputOverview struct {
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
Routes []string `json:"routes" yaml:"routes"`
Networks []string `json:"networks" yaml:"networks"`
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
}
@@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
Routes: pbFullStatus.GetLocalPeerState().GetNetworks(),
Networks: pbFullStatus.GetLocalPeerState().GetNetworks(),
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
}
@@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput {
TransferSent: transferSent,
Latency: pbPeerState.GetLatency().AsDuration(),
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
Routes: pbPeerState.GetRoutes(),
Routes: pbPeerState.GetNetworks(),
Networks: pbPeerState.GetNetworks(),
}
peersStateDetail = append(peersStateDetail, peerState)
@@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
}
routes := "-"
if len(overview.Routes) > 0 {
sort.Strings(overview.Routes)
routes = strings.Join(overview.Routes, ", ")
networks := "-"
if len(overview.Networks) > 0 {
sort.Strings(overview.Networks)
networks = strings.Join(overview.Networks, ", ")
}
var dnsServersString string
@@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
"Interface type: %s\n"+
"Quantum resistance: %s\n"+
"Routes: %s\n"+
"Networks: %s\n"+
"Peers count: %s\n",
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
overview.DaemonVersion,
@@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays
interfaceIP,
interfaceTypeString,
rosenpassEnabledStatus,
routes,
networks,
networks,
peersCountString,
)
return summary
@@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
}
}
routes := "-"
if len(peerState.Routes) > 0 {
sort.Strings(peerState.Routes)
routes = strings.Join(peerState.Routes, ", ")
networks := "-"
if len(peerState.Networks) > 0 {
sort.Strings(peerState.Networks)
networks = strings.Join(peerState.Networks, ", ")
}
peerString := fmt.Sprintf(
@@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
" Transfer status (received/sent) %s/%s\n"+
" Quantum resistance: %s\n"+
" Routes: %s\n"+
" Networks: %s\n"+
" Latency: %s\n",
peerState.FQDN,
peerState.IP,
@@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
toIEC(peerState.TransferReceived),
toIEC(peerState.TransferSent),
rosenpassEnabledStatus,
routes,
networks,
networks,
peerState.Latency.String(),
)
@@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeIPString(route)
}
for i, route := range peer.Networks {
peer.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range peer.Routes {
peer.Routes[i] = a.AnonymizeIPString(route)
}
@@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview)
}
}
for i, route := range overview.Networks {
overview.Networks[i] = a.AnonymizeRoute(route)
}
for i, route := range overview.Routes {
overview.Routes[i] = a.AnonymizeRoute(route)
}

View File

@@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
BytesRx: 200,
BytesTx: 100,
Routes: []string{
Networks: []string{
"10.1.0.0/24",
},
Latency: durationpb.New(time.Duration(10000000)),
@@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{
PubKey: "Some-Pub-Key",
KernelInterface: true,
Fqdn: "some-localhost.awesome-domain.com",
Routes: []string{
Networks: []string{
"10.10.0.0/24",
},
},
@@ -149,6 +149,9 @@ var overview = statusOutputOverview{
Routes: []string{
"10.1.0.0/24",
},
Networks: []string{
"10.1.0.0/24",
},
Latency: time.Duration(10000000),
},
{
@@ -230,6 +233,9 @@ var overview = statusOutputOverview{
Routes: []string{
"10.10.0.0/24",
},
Networks: []string{
"10.10.0.0/24",
},
}
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
@@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) {
"quantumResistance": false,
"routes": [
"10.1.0.0/24"
],
"networks": [
"10.1.0.0/24"
]
},
{
@@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) {
"transferSent": 1000,
"latency": 10000000,
"quantumResistance": false,
"routes": null
"routes": null,
"networks": null
}
]
},
@@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) {
"routes": [
"10.10.0.0/24"
],
"networks": [
"10.10.0.0/24"
],
"dnsServers": [
{
"servers": [
@@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) {
quantumResistance: false
routes:
- 10.1.0.0/24
networks:
- 10.1.0.0/24
- fqdn: peer-2.awesome-domain.com
netbirdIp: 192.168.178.102
publicKey: Pubkey2
@@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) {
latency: 10ms
quantumResistance: false
routes: []
networks: []
cliVersion: development
daemonVersion: 0.14.1
management:
@@ -465,6 +481,8 @@ quantumResistance: false
quantumResistancePermissive: false
routes:
- 10.10.0.0/24
networks:
- 10.10.0.0/24
dnsServers:
- servers:
- 8.8.8.8:53
@@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) {
Transfer status (received/sent) 200 B/100 B
Quantum resistance: false
Routes: 10.1.0.0/24
Networks: 10.1.0.0/24
Latency: 10ms
peer-2.awesome-domain.com:
@@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) {
Transfer status (received/sent) 2.0 KiB/1000 B
Quantum resistance: false
Routes: -
Networks: -
Latency: 10ms
OS: %s/%s
@@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
@@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16
Interface type: Kernel
Quantum resistance: false
Routes: 10.10.0.0/24
Networks: 10.10.0.0/24
Peers count: 2/2 Connected
`

31
client/cmd/system.go Normal file
View File

@@ -0,0 +1,31 @@
package cmd
// Flag constants for system configuration
const (
disableClientRoutesFlag = "disable-client-routes"
disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall"
)
var (
disableClientRoutes bool
disableServerRoutes bool
disableDNS bool
disableFirewall bool
)
func init() {
// Add system flags to upCmd
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
"Disable DNS. If enabled, the client won't configure DNS settings.")
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
}

View File

@@ -10,6 +10,8 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@@ -71,7 +73,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
t.Fatal(err)
}
s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil {
t.Fatal(err)
}
@@ -93,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
}
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -48,6 +48,7 @@ func init() {
)
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
}
func upFunc(cmd *cobra.Command, args []string) error {
@@ -147,6 +148,23 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.DNSRouteInterval = &dnsRouteInterval
}
if cmd.Flag(disableClientRoutesFlag).Changed {
ic.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
ic.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
ic.DisableDNS = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
ic.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
ic.BlockLANAccess = &blockLANAccess
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
@@ -172,7 +190,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run()
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
@@ -264,6 +282,23 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
}
if cmd.Flag(disableClientRoutesFlag).Changed {
loginRequest.DisableClientRoutes = &disableClientRoutes
}
if cmd.Flag(disableServerRoutesFlag).Changed {
loginRequest.DisableServerRoutes = &disableServerRoutes
}
if cmd.Flag(disableDNSFlag).Changed {
loginRequest.DisableDns = &disableDNS
}
if cmd.Flag(disableFirewallFlag).Changed {
loginRequest.DisableFirewall = &disableFirewall
}
if cmd.Flag(blockLANAccessFlag).Changed {
loginRequest.BlockLanAccess = &blockLANAccess
}
var loginErr error
var loginResp *proto.LoginResponse

24
client/configs/configs.go Normal file
View File

@@ -0,0 +1,24 @@
package configs
import (
"os"
"path/filepath"
"runtime"
)
var StateDir string
func init() {
StateDir = os.Getenv("NB_STATE_DIR")
if StateDir != "" {
return
}
switch runtime.GOOS {
case "windows":
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
case "darwin", "linux":
StateDir = "/var/lib/netbird"
case "freebsd", "openbsd", "netbsd", "dragonfly":
StateDir = "/var/db/netbird"
}
}

View File

@@ -3,7 +3,7 @@ package iptables
import (
"fmt"
"net"
"strconv"
"slices"
"github.com/coreos/go-iptables/iptables"
"github.com/google/uuid"
@@ -19,8 +19,7 @@ const (
tableName = "filter"
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
chainNameInputRules = "NETBIRD-ACL-INPUT"
)
type aclEntries map[string][][]string
@@ -84,28 +83,22 @@ func (m *aclManager) AddPeerFiltering(
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
) ([]firewall.Rule, error) {
var dPortVal, sPortVal string
if dPort != nil && dPort.Values != nil {
// TODO: we support only one port per rule in current implementation of ACLs
dPortVal = strconv.Itoa(dPort.Values[0])
}
if sPort != nil && sPort.Values != nil {
sPortVal = strconv.Itoa(sPort.Values[0])
}
chain := chainNameInputRules
var chain string
if direction == firewall.RuleDirectionOUT {
chain = chainNameOutputRules
} else {
chain = chainNameInputRules
}
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
mangleSpecs := slices.Clone(specs)
mangleSpecs = append(mangleSpecs,
"-i", m.wgIface.Name(),
"-m", "addrtype", "--dst-type", "LOCAL",
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
)
specs = append(specs, "-j", actionToStr(action))
if ipsetName != "" {
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
if err := ipset.Add(ipsetName, ip.String()); err != nil {
@@ -137,7 +130,7 @@ func (m *aclManager) AddPeerFiltering(
m.ipsetStore.addIpList(ipsetName, ipList)
}
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
@@ -145,16 +138,22 @@ func (m *aclManager) AddPeerFiltering(
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
return nil, err
}
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
log.Errorf("failed to add mangle rule: %v", err)
mangleSpecs = nil
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
ruleID: uuid.New().String(),
specs: specs,
mangleSpecs: mangleSpecs,
ipsetName: ipsetName,
ip: ip.String(),
chain: chain,
}
m.updateState()
@@ -197,6 +196,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
}
if r.mangleSpecs != nil {
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
m.updateState()
return nil
@@ -214,28 +219,7 @@ func (m *aclManager) Reset() error {
// todo write less destructive cleanup mechanism
func (m *aclManager) cleanChains() error {
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
}
if ok {
rules := m.entries["OUTPUT"]
for _, rule := range rules {
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
if err != nil {
log.Errorf("failed to delete rule: %v, %s", rule, err)
}
}
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
if err != nil {
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
return err
}
}
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
if err != nil {
log.Debugf("failed to list chains: %s", err)
return err
@@ -295,12 +279,6 @@ func (m *aclManager) createDefaultChains() error {
return err
}
// chain netbird-acl-output-rules
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
return err
}
for chainName, rules := range m.entries {
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
@@ -329,21 +307,13 @@ func (m *aclManager) createDefaultChains() error {
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() {
established := getConntrackEstablished()
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
@@ -352,17 +322,10 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
}
func (m *aclManager) appendToEntries(chainName string, spec []string) {
@@ -396,42 +359,26 @@ func (m *aclManager) updateState() {
}
// filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
) (specs []string) {
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
matchByIP := true
// don't use IP matching if IP is ip 0.0.0.0
if ip.String() == "0.0.0.0" {
matchByIP = false
}
switch direction {
case firewall.RuleDirectionIN:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
case firewall.RuleDirectionOUT:
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
} else {
specs = append(specs, "-d", ip.String())
}
if matchByIP {
if ipsetName != "" {
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
} else {
specs = append(specs, "-s", ip.String())
}
}
if protocol != "all" {
specs = append(specs, "-p", protocol)
}
if sPort != "" {
specs = append(specs, "--sport", sPort)
}
if dPort != "" {
specs = append(specs, "--dport", dPort)
}
return append(specs, "-j", actionToStr(action))
specs = append(specs, applyPort("--sport", sPort)...)
specs = append(specs, applyPort("--dport", dPort)...)
return specs
}
func actionToStr(action firewall.Action) string {
@@ -441,15 +388,15 @@ func actionToStr(action firewall.Action) string {
return "DROP"
}
func transformIPsetName(ipsetName string, sPort, dPort string) string {
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
switch {
case ipsetName == "":
return ""
case sPort != "" && dPort != "":
case sPort != nil && dPort != nil:
return ipsetName + "-sport-dport"
case sPort != "":
case sPort != nil:
return ipsetName + "-sport"
case dPort != "":
case dPort != nil:
return ipsetName + "-dport"
default:
return ipsetName

View File

@@ -100,15 +100,14 @@ func (m *Manager) AddPeerFiltering(
protocol firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
_ string,
) ([]firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
}
func (m *Manager) AddRouteFiltering(
@@ -197,29 +196,18 @@ func (m *Manager) AllowNetbird() error {
}
_, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
net.IP{0, 0, 0, 0},
"all",
nil,
nil,
firewall.RuleDirectionIN,
firewall.ActionAccept,
"",
"",
)
if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
_, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
"all",
nil,
nil,
firewall.RuleDirectionOUT,
firewall.ActionAccept,
"",
"",
)
return err
return nil
}
// Flush doesn't need to be implemented for this manager

View File

@@ -68,27 +68,14 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
}
})
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{8043: 8046},
IsRange: true,
Values: []uint16{8043, 8046},
}
rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule")
for _, r := range rule2 {
@@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
}
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
@@ -118,8 +96,8 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) {
// add second rule
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil)
@@ -135,9 +113,6 @@ func TestIptablesManager(t *testing.T) {
}
func TestIptablesManagerIPSet(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
@@ -167,33 +142,13 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second)
}()
var rule1 []fw.Rule
t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic",
)
require.NoError(t, err, "failed to add rule")
for _, r := range rule1 {
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
}
})
var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3")
port := &fw.Port{
Values: []int{443},
Values: []uint16{443},
}
rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range",
)
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
for _, r := range rule2 {
require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
})
t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 {
err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
}
})
t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 {
err := manager.DeletePeerRule(r)
@@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}

View File

@@ -590,10 +590,10 @@ func applyPort(flag string, port *firewall.Port) []string {
if len(port.Values) > 1 {
portList := make([]string, len(port.Values))
for i, p := range port.Values {
portList[i] = strconv.Itoa(p)
portList[i] = strconv.Itoa(int(p))
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
}
return []string{flag, strconv.Itoa(port.Values[0])}
return []string{flag, strconv.Itoa(int(port.Values[0]))}
}

View File

@@ -239,7 +239,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
@@ -252,7 +252,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
@@ -285,7 +285,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
@@ -297,7 +297,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
@@ -307,8 +307,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,

View File

@@ -5,9 +5,10 @@ type Rule struct {
ruleID string
ipsetName string
specs []string
ip string
chain string
specs []string
mangleSpecs []string
ip string
chain string
}
// GetRuleID returns the rule id

View File

@@ -69,7 +69,6 @@ type Manager interface {
proto Protocol,
sPort *Port,
dPort *Port,
direction RuleDirection,
action Action,
ipsetName string,
comment string,

View File

@@ -30,7 +30,7 @@ type Port struct {
IsRange bool
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
Values []int
Values []uint16
}
// String interface implementation
@@ -40,7 +40,11 @@ func (p *Port) String() string {
if ports != "" {
ports += ","
}
ports += strconv.Itoa(port)
ports += strconv.Itoa(int(port))
}
if p.IsRange {
ports = "range:" + ports
}
return ports
}

View File

@@ -2,10 +2,9 @@ package nftables
import (
"bytes"
"encoding/binary"
"fmt"
"net"
"net/netip"
"slices"
"strconv"
"strings"
"time"
@@ -23,12 +22,10 @@ import (
const (
// rules chains contains the effective ACL rules
chainNameInputRules = "netbird-acl-input-rules"
chainNameOutputRules = "netbird-acl-output-rules"
chainNameInputRules = "netbird-acl-input-rules"
// filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
@@ -47,9 +44,9 @@ type AclManager struct {
wgIface iFaceMapper
routingFwChainName string
workTable *nftables.Table
chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain
workTable *nftables.Table
chainInputRules *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore
rules map[string]*Rule
@@ -91,7 +88,6 @@ func (m *AclManager) AddPeerFiltering(
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
@@ -106,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
}
newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
if err != nil {
return nil, err
}
@@ -123,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
}
if r.nftSet == nil {
err := m.rConn.DelRule(r.nftRule)
if err != nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID())
return m.rConn.Flush()
}
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
if !ok {
err := m.rConn.DelRule(r.nftRule)
if err != nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
delete(m.rules, r.GetRuleID())
return m.rConn.Flush()
}
if _, ok := ips[r.ip.String()]; ok {
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
if err != nil {
@@ -158,12 +163,16 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return nil
}
err := m.rConn.DelRule(r.nftRule)
if err != nil {
if err := m.rConn.DelRule(r.nftRule); err != nil {
log.Errorf("failed to delete rule: %v", err)
}
err = m.rConn.Flush()
if err != nil {
if r.mangleRule != nil {
if err := m.rConn.DelRule(r.mangleRule); err != nil {
log.Errorf("failed to delete mangle rule: %v", err)
}
}
if err := m.rConn.Flush(); err != nil {
return err
}
@@ -216,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expIn,
})
expOut := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainOutputRules,
Position: 0,
Exprs: expOut,
})
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
@@ -262,25 +239,33 @@ func (m *AclManager) Flush() error {
return err
}
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
}
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
}
return nil
}
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
func (m *AclManager) addIOFiltering(
ip net.IP,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
ipset *nftables.Set,
comment string,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
nftRule: r.nftRule,
mangleRule: r.mangleRule,
nftSet: r.nftSet,
ruleID: r.ruleID,
ip: ip,
}, nil
}
@@ -312,9 +297,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
if !bytes.HasPrefix(anyIP, rawIP) {
// source address position
addrOffset := uint32(12)
if direction == firewall.RuleDirectionOUT {
addrOffset += 4 // is ipv4 address length
}
expressions = append(expressions,
&expr.Payload{
@@ -344,73 +326,100 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}
}
if sPort != nil && len(sPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 0,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*sPort),
},
)
}
expressions = append(expressions, applyPort(sPort, true)...)
expressions = append(expressions, applyPort(dPort, false)...)
if dPort != nil && len(dPort.Values) != 0 {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*dPort),
},
)
}
mainExpressions := slices.Clone(expressions)
switch action {
case firewall.ActionAccept:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
case firewall.ActionDrop:
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
}
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
var chain *nftables.Chain
if direction == firewall.RuleDirectionIN {
chain = m.chainInputRules
} else {
chain = m.chainOutputRules
}
chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
Exprs: mainExpressions,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
nftRule: nftRule,
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
if m.chainPrerouting == nil {
log.Warn("prerouting chain is not created")
return nil
}
preroutingExprs := slices.Clone(expressions)
// interface
preroutingExprs = append([]expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}, preroutingExprs...)
// local destination and mark
preroutingExprs = append(preroutingExprs,
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
)
return m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
}
func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules
chain := m.createChain(chainNameInputRules)
@@ -421,15 +430,6 @@ func (m *AclManager) createDefaultChains() (err error) {
}
m.chainInputRules = chain
// chainNameOutputRules
chain = m.createChain(chainNameOutputRules)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
return err
}
m.chainOutputRules = chain
// netbird-acl-input-filter
// type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
@@ -441,18 +441,6 @@ func (m *AclManager) createDefaultChains() (err error) {
return err
}
// netbird-acl-output-filter
// type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
return err
}
// netbird-acl-forward-filter
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
@@ -475,7 +463,7 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
@@ -483,8 +471,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
@@ -494,43 +480,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
@@ -546,8 +495,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
Kind: expr.VerdictAccept,
},
},
})
@@ -619,45 +567,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
dstOp := expr.CmpOpNeq
expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
@@ -733,6 +642,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
for i := 0; ; i++ {
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to flush nftables: %v", err)
if !strings.Contains(err.Error(), "busy") {
return
}
@@ -749,7 +659,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
return
}
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
if m.workTable == nil || chain == nil {
return nil
}
@@ -766,22 +676,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
split := bytes.Split(rule.UserData, []byte(" "))
r, ok := m.rules[string(split[0])]
if ok {
*r.nftRule = *rule
if mangle {
*r.mangleRule = *rule
} else {
*r.nftRule = *rule
}
}
}
return nil
}
func generatePeerRuleId(
ip net.IP,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipset *nftables.Set,
) string {
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":"
if sPort != nil {
rulesetID += sPort.String()
}
@@ -797,12 +704,6 @@ func generatePeerRuleId(
return "set:" + ipset.Name + rulesetID
}
func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2)
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
return bs
}
func ifname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")

View File

@@ -117,7 +117,6 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
comment string,
@@ -130,10 +129,17 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
}
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
}
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()

View File

@@ -74,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{53}},
fw.RuleDirectionIN,
fw.ActionDrop,
"",
"",
)
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
require.NoError(t, err, "failed to add rule")
err = manager.Flush()
@@ -209,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
if i%100 == 0 {
@@ -296,16 +283,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
})
ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(
ip,
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{80}},
fw.RuleDirectionIN,
fw.ActionAccept,
"",
"test rule",
)
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering(
@@ -313,7 +291,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP,
nil,
&fw.Port{Values: []int{443}},
&fw.Port{Values: []uint16{443}},
fw.ActionAccept,
)
require.NoError(t, err, "failed to add route filtering rule")

View File

@@ -956,12 +956,12 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
},
)
} else {
@@ -980,7 +980,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
Data: binaryutil.BigEndian.PutUint16(p),
})
}
}

View File

@@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
dPort: &firewall.Port{Values: []uint16{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
@@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
@@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
@@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
@@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []uint16{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,

View File

@@ -8,10 +8,11 @@ import (
// Rule to handle management of rules
type Rule struct {
nftRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
nftRule *nftables.Rule
mangleRule *nftables.Rule
nftSet *nftables.Set
ruleID string
ip net.IP
}
// GetRuleID returns the rule id

View File

@@ -2,7 +2,10 @@
package uspfilter
import "github.com/netbirdio/netbird/client/internal/statemanager"
import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
@@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}

View File

@@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.outgoingRules = make(map[string]RuleSet)
m.incomingRules = make(map[string]RuleSet)
if m.udpTracker != nil {
m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
}
if m.icmpTracker != nil {
m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}
if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if !isWindowsFirewallReachable() {
return nil
}

View File

@@ -0,0 +1,137 @@
// common.go
package conntrack
import (
"net"
"sync"
"sync/atomic"
"time"
)
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
}
// these small methods will be inlined by the compiler
// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}
// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}
// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}
// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}
// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}

View File

@@ -0,0 +1,115 @@
package conntrack
import (
"net"
"testing"
)
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})
b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})
b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})
b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}
// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})
b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})
}

View File

@@ -0,0 +1,170 @@
package conntrack
import (
"net"
"sync"
"time"
"github.com/google/gopacket/layers"
)
const (
// DefaultICMPTimeout is the default timeout for ICMP connections
DefaultICMPTimeout = 30 * time.Second
// ICMPCleanupInterval is how often we check for stale ICMP connections
ICMPCleanupInterval = 15 * time.Second
)
// ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct {
// Supports both IPv4 and IPv6
SrcIP [16]byte
DstIP [16]byte
Sequence uint16 // ICMP sequence number
ID uint16 // ICMP identifier
}
// ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct {
BaseConnTrack
Sequence uint16
ID uint16
}
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
}
// NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration) *ICMPTracker {
if timeout == 0 {
timeout = DefaultICMPTimeout
}
tracker := &ICMPTracker{
connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound ICMP Echo Request
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
key := makeICMPKey(srcIP, dstIP, id, seq)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
}
t.mutex.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
switch icmpType {
case uint8(layers.ICMPv4TypeDestinationUnreachable),
uint8(layers.ICMPv4TypeTimeExceeded):
return true
case uint8(layers.ICMPv4TypeEchoReply):
// continue processing
default:
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
}
func (t *ICMPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
return ICMPConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
ID: id,
Sequence: seq,
}
}

View File

@@ -0,0 +1,39 @@
package conntrack
import (
"net"
"testing"
)
func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
}
})
}

View File

@@ -0,0 +1,352 @@
package conntrack
// TODO: Send RST packets for invalid/timed-out connections
import (
"net"
"sync"
"time"
)
const (
// MSL (Maximum Segment Lifetime) is typically 2 minutes
MSL = 2 * time.Minute
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
TimeWaitTimeout = 2 * MSL
)
const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPUrg uint8 = 0x20
)
const (
// DefaultTCPTimeout is the default timeout for established TCP connections
DefaultTCPTimeout = 3 * time.Hour
// TCPHandshakeTimeout is timeout for TCP handshake completion
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
)
// TCPState represents the state of a TCP connection
type TCPState int
const (
TCPStateNew TCPState = iota
TCPStateSynSent
TCPStateSynReceived
TCPStateEstablished
TCPStateFinWait1
TCPStateFinWait2
TCPStateClosing
TCPStateTimeWait
TCPStateCloseWait
TCPStateLastAck
TCPStateClosed
)
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
BaseConnTrack
State TCPState
sync.RWMutex
}
// TCPTracker manages TCP connection states
type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
done chan struct{}
timeout time.Duration
ipPool *PreallocatedIPs
}
// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
// Create key before lock
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.lastSeen.Store(now)
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
// Handle RST packets
if flags&TCPRst != 0 {
conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true
}
conn.Unlock()
return false
}
conn.Lock()
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
return isEstablished || isValidState
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
return
}
switch conn.State {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
}
case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
}
case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished
conn.SetEstablished(true)
}
case TCPStateEstablished:
if flags&TCPFin != 0 {
if isOutbound {
conn.State = TCPStateFinWait1
} else {
conn.State = TCPStateCloseWait
}
conn.SetEstablished(false)
}
case TCPStateFinWait1:
switch {
case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing
case flags&TCPFin != 0:
conn.State = TCPStateFinWait2
case flags&TCPAck != 0:
conn.State = TCPStateFinWait2
}
case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
}
case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
// Keep established = false from previous state
}
case TCPStateCloseWait:
if flags&TCPFin != 0 {
conn.State = TCPStateLastAck
}
case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
}
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
}
}
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
switch state {
case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent:
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0
case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateFinWait2:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateClosing:
// In CLOSING state, we should accept the final ACK
return flags&TCPAck != 0
case TCPStateTimeWait:
// In TIME_WAIT, we might see retransmissions
return flags&TCPAck != 0
case TCPStateCloseWait:
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
}
func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *TCPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.IsEstablished():
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}
lastSeen := conn.GetLastSeen()
if time.Since(lastSeen) > timeout {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
// Clean up all remaining IPs
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
func isValidFlagCombination(flags uint8) bool {
// Invalid: SYN+FIN
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
return false
}
// Invalid: RST with SYN or FIN
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
return false
}
return true
}

View File

@@ -0,0 +1,308 @@
package conntrack
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Security Tests", func(t *testing.T) {
tests := []struct {
name string
flags uint8
wantDrop bool
desc string
}{
{
name: "Block unsolicited SYN-ACK",
flags: TCPSyn | TCPAck,
wantDrop: true,
desc: "Should block SYN-ACK without prior SYN",
},
{
name: "Block invalid SYN-FIN",
flags: TCPSyn | TCPFin,
wantDrop: true,
desc: "Should block invalid SYN-FIN combination",
},
{
name: "Block unsolicited RST",
flags: TCPRst,
wantDrop: true,
desc: "Should block RST without connection",
},
{
name: "Block unsolicited ACK",
flags: TCPAck,
wantDrop: true,
desc: "Should block ACK without connection",
},
{
name: "Block data without connection",
flags: TCPAck | TCPPush,
wantDrop: true,
desc: "Should block data without established connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
})
}
})
t.Run("Connection Flow Tests", func(t *testing.T) {
tests := []struct {
name string
test func(*testing.T)
desc string
}{
{
name: "Normal Handshake",
test: func(t *testing.T) {
t.Helper()
// Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
// Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
// Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
require.True(t, valid, "Data should be allowed after handshake")
},
},
{
name: "Normal Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "FIN should be allowed")
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
},
{
name: "RST During Connection",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{
name: "Simultaneous Close",
test: func(t *testing.T) {
t.Helper()
// First establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
require.True(t, valid, "Final ACKs should be allowed")
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout)
tt.test(t)
})
}
})
}
func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
tests := []struct {
name string
setupState func()
sendRST func()
wantValid bool
desc string
}{
{
name: "RST in established",
setupState: func() {
// Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: true,
desc: "Should accept RST for established connection",
},
{
name: "RST without connection",
setupState: func() {},
sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
},
wantValid: false,
desc: "Should reject RST without connection",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setupState()
tt.sendRST()
// Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
if tt.wantValid {
require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State)
require.False(t, conn.IsEstablished())
}
})
}
}
// Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -0,0 +1,158 @@
package conntrack
import (
"net"
"sync"
"time"
)
const (
// DefaultUDPTimeout is the default timeout for UDP connections
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
)
// UDPConnTrack represents a UDP connection state
type UDPConnTrack struct {
BaseConnTrack
}
// UDPTracker manages UDP connection states
type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack
timeout time.Duration
cleanupTicker *time.Ticker
mutex sync.RWMutex
done chan struct{}
ipPool *PreallocatedIPs
}
// NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration) *UDPTracker {
if timeout == 0 {
timeout = DefaultUDPTimeout
}
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
go tracker.cleanupRoutine()
return tracker
}
// TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now().UnixNano()
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.lastSeen.Store(now)
conn.established.Store(true)
t.connections[key] = conn
}
t.mutex.Unlock()
conn.lastSeen.Store(now)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists {
return false
}
if conn.timeoutExceeded(t.timeout) {
return false
}
return conn.IsEstablished() &&
ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}
func (t *UDPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key)
}
}
}
// Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil
t.mutex.Unlock()
}
// GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock()
defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := t.connections[key]
if !exists {
return nil, false
}
return conn, true
}
// Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration {
return t.timeout
}

View File

@@ -0,0 +1,243 @@
package conntrack
import (
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewUDPTracker(t *testing.T) {
tests := []struct {
name string
timeout time.Duration
wantTimeout time.Duration
}{
{
name: "with custom timeout",
timeout: 1 * time.Minute,
wantTimeout: 1 * time.Minute,
},
{
name: "with zero timeout uses default",
timeout: 0,
wantTimeout: DefaultUDPTimeout,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout)
assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done)
})
}
}
func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
// Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
conn, exists := tracker.connections[key]
require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP))
assert.True(t, conn.DestIP.Equal(dstIP))
assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort)
assert.True(t, conn.IsEstablished())
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
}
func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1 * time.Second)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3")
srcPort := uint16(12345)
dstPort := uint16(53)
// Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
tests := []struct {
name string
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
sleep time.Duration
want bool
}{
{
name: "valid inbound response",
srcIP: dstIP, // Original destination is now source
dstIP: srcIP, // Original source is now destination
srcPort: dstPort, // Original destination port is now source
dstPort: srcPort, // Original source port is now destination
sleep: 0,
want: true,
},
{
name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"),
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination IP",
srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"),
srcPort: dstPort,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid source port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: 54321,
dstPort: srcPort,
sleep: 0,
want: false,
},
{
name: "invalid destination port",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: 54321,
sleep: 0,
want: false,
},
{
name: "expired connection",
srcIP: dstIP,
dstIP: srcIP,
srcPort: dstPort,
dstPort: srcPort,
sleep: 2 * time.Second, // Longer than tracker timeout
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.sleep > 0 {
time.Sleep(tt.sleep)
}
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
assert.Equal(t, tt.want, got)
})
}
}
func TestUDPTracker_Cleanup(t *testing.T) {
// Use shorter intervals for testing
timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond
// Create tracker with custom cleanup interval
tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}),
ipPool: NewPreallocatedIPs(),
}
// Start cleanup routine
go tracker.cleanupRoutine()
// Add some connections
connections := []struct {
srcIP net.IP
dstIP net.IP
srcPort uint16
dstPort uint16
}{
{
srcIP: net.ParseIP("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"),
srcPort: 12345,
dstPort: 53,
},
{
srcIP: net.ParseIP("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"),
srcPort: 12346,
dstPort: 53,
},
}
for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
}
// Verify initial connections
assert.Len(t, tracker.connections, 2)
// Wait for connection timeout and cleanup interval
time.Sleep(timeout + 2*cleanupInterval)
tracker.mutex.RLock()
connCount := len(tracker.connections)
tracker.mutex.RUnlock()
// Verify connections were cleaned up
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
// Properly close the tracker
tracker.Close()
}
func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
}
})
}

View File

@@ -15,9 +15,8 @@ type Rule struct {
ipLayer gopacket.LayerType
matchByIP bool
protoLayer gopacket.LayerType
direction firewall.RuleDirection
sPort uint16
dPort uint16
sPort *firewall.Port
dPort *firewall.Port
drop bool
comment string

View File

@@ -4,6 +4,8 @@ import (
"fmt"
"net"
"net/netip"
"os"
"strconv"
"sync"
"github.com/google/gopacket"
@@ -12,6 +14,7 @@ import (
log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -19,6 +22,8 @@ import (
const layerTypeAll = 0
const EnvDisableConntrack = "NB_DISABLE_CONNTRACK"
var (
errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall")
)
@@ -34,7 +39,9 @@ type RuleSet map[string]Rule
// Manager userspace firewall manager
type Manager struct {
outgoingRules map[string]RuleSet
// outgoingRules is used for hooks only
outgoingRules map[string]RuleSet
// incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet
wgNetwork *net.IPNet
decoders sync.Pool
@@ -42,6 +49,11 @@ type Manager struct {
nativeFirewall firewall.Manager
mutex sync.RWMutex
stateful bool
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
}
// decoder for packages
@@ -73,6 +85,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager
}
func create(iface IFaceMapper) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))
m := &Manager{
decoders: sync.Pool{
New: func() any {
@@ -90,6 +104,16 @@ func create(iface IFaceMapper) (*Manager, error) {
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
wgIface: iface,
stateful: !disableConntrack,
}
// Only initialize trackers if stateful mode is enabled
if disableConntrack {
log.Info("conntrack is disabled")
} else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}
if err := iface.SetFilter(m); err != nil {
@@ -134,9 +158,8 @@ func (m *Manager) AddPeerFiltering(
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
direction firewall.RuleDirection,
action firewall.Action,
ipsetName string,
_ string,
comment string,
) ([]firewall.Rule, error) {
r := Rule{
@@ -144,7 +167,6 @@ func (m *Manager) AddPeerFiltering(
ip: ip,
ipLayer: layers.LayerTypeIPv6,
matchByIP: true,
direction: direction,
drop: action == firewall.ActionDrop,
comment: comment,
}
@@ -157,13 +179,8 @@ func (m *Manager) AddPeerFiltering(
r.matchByIP = false
}
if sPort != nil && len(sPort.Values) == 1 {
r.sPort = uint16(sPort.Values[0])
}
if dPort != nil && len(dPort.Values) == 1 {
r.dPort = uint16(dPort.Values[0])
}
r.sPort = sPort
r.dPort = dPort
switch proto {
case firewall.ProtocolTCP:
@@ -180,17 +197,10 @@ func (m *Manager) AddPeerFiltering(
}
m.mutex.Lock()
if direction == firewall.RuleDirectionIN {
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
m.outgoingRules[r.ip.String()] = make(RuleSet)
}
m.outgoingRules[r.ip.String()][r.id] = r
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet)
}
m.incomingRules[r.ip.String()][r.id] = r
m.mutex.Unlock()
return []firewall.Rule{&r}, nil
}
@@ -219,19 +229,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
}
if r.direction == firewall.RuleDirectionIN {
_, ok := m.incomingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
} else {
_, ok := m.outgoingRules[r.ip.String()][r.id]
if !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.outgoingRules[r.ip.String()], r.id)
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
}
delete(m.incomingRules[r.ip.String()], r.id)
return nil
}
@@ -249,16 +250,16 @@ func (m *Manager) Flush() error { return nil }
// DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool {
return m.dropFilter(packetData, m.outgoingRules, false)
return m.processOutgoingHooks(packetData)
}
// DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool {
return m.dropFilter(packetData, m.incomingRules, true)
return m.dropFilter(packetData, m.incomingRules)
}
// dropFilter implements same logic for booth direction of the traffic
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool {
// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
@@ -266,64 +267,235 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco
defer m.decoders.Put(d)
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err)
return true
return false
}
if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet")
return false
}
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
return false
}
// Always process UDP hooks
if d.decoded[1] == layers.LayerTypeUDP {
// Track UDP state only if enabled
if m.stateful {
m.trackUDPOutbound(d, srcIP, dstIP)
}
return m.checkUDPHooks(d, dstIP, packetData)
}
// Track other protocols only if stateful mode is enabled
if m.stateful {
switch d.decoded[1] {
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
}
}
return false
}
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
switch d.decoded[0] {
case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP
case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP
default:
return nil, nil
}
}
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8
if tcp.SYN {
flags |= conntrack.TCPSyn
}
if tcp.ACK {
flags |= conntrack.TCPAck
}
if tcp.FIN {
flags |= conntrack.TCPFin
}
if tcp.RST {
flags |= conntrack.TCPRst
}
if tcp.PSH {
flags |= conntrack.TCPPush
}
if tcp.URG {
flags |= conntrack.TCPUrg
}
return flags
}
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
m.udpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
}
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
if rules, exists := m.outgoingRules[ipKey]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.udpHook(packetData)
}
}
}
}
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
m.icmpTracker.TrackOutbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
)
}
}
// dropFilter implements filtering logic for incoming packets
func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool {
// TODO: Disable router if --disable-server-router is set
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) {
return true
}
ipLayer := d.decoded[0]
switch ipLayer {
case layers.LayerTypeIPv4:
if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) {
return false
}
case layers.LayerTypeIPv6:
if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) {
return false
}
default:
srcIP, dstIP := m.extractIPs(d)
if srcIP == nil {
log.Errorf("unknown layer: %v", d.decoded[0])
return true
}
var ip net.IP
switch ipLayer {
case layers.LayerTypeIPv4:
if isIncomingPacket {
ip = d.ip4.SrcIP
} else {
ip = d.ip4.DstIP
}
case layers.LayerTypeIPv6:
if isIncomingPacket {
ip = d.ip6.SrcIP
} else {
ip = d.ip6.DstIP
}
if !m.isWireguardTraffic(srcIP, dstIP) {
return false
}
filter, ok := validateRule(ip, packetData, rules[ip.String()], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d)
if ok {
return filter
}
filter, ok = validateRule(ip, packetData, rules["::"], d)
if ok {
return filter
// Check connection state only if enabled
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
return false
}
// default policy is DROP ALL
return m.applyRules(srcIP, packetData, rules, d)
}
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
log.Tracef("couldn't decode layer, err: %s", err)
return false
}
if len(d.decoded) < 2 {
log.Tracef("not enough levels in network packet")
return false
}
return true
}
func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool {
return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP)
}
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
switch d.decoded[1] {
case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp),
)
case layers.LayerTypeUDP:
return m.udpTracker.IsValidInbound(
srcIP,
dstIP,
uint16(d.udp.SrcPort),
uint16(d.udp.DstPort),
)
case layers.LayerTypeICMPv4:
return m.icmpTracker.IsValidInbound(
srcIP,
dstIP,
d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(),
)
// TODO: ICMPv6
}
return false
}
func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
return filter
}
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
return filter
}
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
return filter
}
// Default policy: DROP ALL
return true
}
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
if rulePort == nil {
return true
}
if rulePort.IsRange {
return packetPort >= rulePort.Values[0] && packetPort <= rulePort.Values[1]
}
for _, p := range rulePort.Values {
if p == packetPort {
return true
}
}
return false
}
func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) {
payloadLayer := d.decoded[1]
for _, rule := range rules {
@@ -341,13 +513,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
switch payloadLayer {
case layers.LayerTypeTCP:
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop, true
}
if rule.sPort != 0 && rule.sPort == uint16(d.tcp.SrcPort) {
return rule.drop, true
}
if rule.dPort != 0 && rule.dPort == uint16(d.tcp.DstPort) {
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.drop, true
}
case layers.LayerTypeUDP:
@@ -357,13 +523,7 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode
return rule.udpHook(packetData), true
}
if rule.sPort == 0 && rule.dPort == 0 {
return rule.drop, true
}
if rule.sPort != 0 && rule.sPort == uint16(d.udp.SrcPort) {
return rule.drop, true
}
if rule.dPort != 0 && rule.dPort == uint16(d.udp.DstPort) {
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.drop, true
}
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
@@ -388,9 +548,8 @@ func (m *Manager) AddUDPPacketHook(
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: dPort,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
direction: firewall.RuleDirectionOUT,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook,
}
@@ -401,7 +560,6 @@ func (m *Manager) AddUDPPacketHook(
m.mutex.Lock()
if in {
r.direction = firewall.RuleDirectionIN
if _, ok := m.incomingRules[r.ip.String()]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]Rule)
}
@@ -420,19 +578,22 @@ func (m *Manager) AddUDPPacketHook(
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeletePeerRule(&rule)
delete(arr, r.id)
return nil
}
}
}
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
rule := r
return m.DeletePeerRule(&rule)
delete(arr, r.id)
return nil
}
}
}

View File

@@ -0,0 +1,998 @@
package uspfilter
import (
"fmt"
"math/rand"
"net"
"os"
"strings"
"testing"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface/device"
)
// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range
func generateRandomIPs(n int) []net.IP {
ips := make([]net.IP, n)
seen := make(map[string]bool)
for i := 0; i < n; {
ip := make(net.IP, 4)
ip[0] = 100
ip[1] = byte(64 + rand.Intn(63)) // 64-126
ip[2] = byte(rand.Intn(256))
ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255
key := ip.String()
if !seen[key] {
ips[i] = ip
seen[key] = true
i++
}
}
return ips
}
func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: protocol,
}
var transportLayer gopacket.SerializableLayer
switch protocol {
case layers.IPProtocolTCP:
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
transportLayer = tcp
case layers.IPProtocolUDP:
udp := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4))
transportLayer = udp
}
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
require.NoError(b, err)
return buf.Bytes()
}
// BenchmarkCoreFiltering focuses on the essential performance comparisons between
// stateful and stateless filtering approaches
func BenchmarkCoreFiltering(b *testing.B) {
scenarios := []struct {
name string
stateful bool
setupFunc func(*Manager)
desc string
}{
{
name: "stateless_single_allow_all",
stateful: false,
setupFunc: func(m *Manager) {
// Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
fw.ActionAccept, "", "allow all")
require.NoError(b, err)
},
desc: "Baseline: Single 'allow all' rule without connection tracking",
},
{
name: "stateful_no_rules",
stateful: true,
setupFunc: func(m *Manager) {
// No explicit rules - rely purely on connection tracking
},
desc: "Pure connection tracking without any rules",
},
{
name: "stateless_explicit_return",
stateful: false,
setupFunc: func(m *Manager) {
// Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}},
fw.ActionAccept, "", "explicit return")
require.NoError(b, err)
}
},
desc: "Explicit rules matching return traffic patterns without state",
},
{
name: "stateful_with_established",
stateful: true,
setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
fw.ActionDrop, "", "default drop")
require.NoError(b, err)
},
desc: "Connection tracking with established connections",
},
}
// Test both TCP and UDP
protocols := []struct {
name string
proto layers.IPProtocol
}{
{"TCP", layers.IPProtocolTCP},
{"UDP", layers.IPProtocolUDP},
}
for _, sc := range scenarios {
for _, proto := range protocols {
b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1"))
} else {
require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m"))
}
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup
sc.setupFunc(manager)
// Generate test packets
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
srcPort := uint16(1024 + b.N%60000)
dstPort := uint16(80)
outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto)
inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto)
// For stateful scenarios, establish the connection
if sc.stateful {
manager.processOutgoingHooks(outbound)
}
// Measure inbound packet processing
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
}
})
}
}
}
// BenchmarkStateScaling measures how performance scales with connection table size
func BenchmarkStateScaling(b *testing.B) {
connCounts := []int{100, 1000, 10000, 100000}
for _, count := range connCounts {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table
srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count)
for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound)
}
// Test packet
testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP)
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection
manager.processOutgoingHooks(testOut)
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(testIn, manager.incomingRules)
}
})
}
}
// BenchmarkEstablishmentOverhead measures the overhead of connection establishment
func BenchmarkEstablishmentOverhead(b *testing.B) {
scenarios := []struct {
name string
established bool
}{
{"established", true},
{"new", false},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established {
manager.processOutgoingHooks(outbound)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
}
})
}
}
// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic
func BenchmarkRoutedNetworkReturn(b *testing.B) {
scenarios := []struct {
name string
proto layers.IPProtocol
state string // "new", "established", "post_handshake" (TCP only)
setupFunc func(*Manager)
genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario
desc string
}{
{
name: "allow_non_wg_tcp_new",
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
},
desc: "Allow non-WG: TCP new connection",
},
{
name: "allow_non_wg_tcp_established",
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
// Generate packets with ACK flag for established connection
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
},
desc: "Allow non-WG: TCP established connection",
},
{
name: "allow_non_wg_udp_new",
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Allow non-WG: UDP new connection",
},
{
name: "allow_non_wg_udp_established",
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1")
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Allow non-WG: UDP established connection",
},
{
name: "stateful_tcp_new",
proto: layers.IPProtocolTCP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
},
desc: "Stateful: TCP new connection",
},
{
name: "stateful_tcp_established",
proto: layers.IPProtocolTCP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
// Generate established TCP packets (ACK flag)
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)),
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck))
},
desc: "Stateful: TCP established connection",
},
{
name: "stateful_tcp_post_handshake",
proto: layers.IPProtocolTCP,
state: "post_handshake",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
// Generate packets with PSH+ACK flags for data transfer
return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
desc: "Stateful: TCP post-handshake data transfer",
},
{
name: "stateful_udp_new",
proto: layers.IPProtocolUDP,
state: "new",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Stateful: UDP new connection",
},
{
name: "stateful_udp_established",
proto: layers.IPProtocolUDP,
state: "established",
setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
},
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP),
generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP)
},
desc: "Stateful: UDP established connection",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
// Setup scenario
sc.setupFunc(manager)
// Use IPs outside WG range for routed network simulation
srcIP := net.ParseIP("192.168.1.2")
dstIP := net.ParseIP("8.8.8.8")
outbound, inbound := sc.genPackets(srcIP, dstIP)
// For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound)
// For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" {
// SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
// ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.dropFilter(inbound, manager.incomingRules)
}
})
}
}
var scenarios = []struct {
name string
stateful bool // Whether conntrack is enabled
rules bool // Whether to add return traffic rules
routed bool // Whether to test routed network traffic
connCount int // Number of concurrent connections
desc string
}{
{
name: "stateless_with_rules_100conns",
stateful: false,
rules: true,
routed: false,
connCount: 100,
desc: "Pure stateless with return traffic rules, 100 conns",
},
{
name: "stateless_with_rules_1000conns",
stateful: false,
rules: true,
routed: false,
connCount: 1000,
desc: "Pure stateless with return traffic rules, 1000 conns",
},
{
name: "stateful_no_rules_100conns",
stateful: true,
rules: false,
routed: false,
connCount: 100,
desc: "Pure stateful tracking without rules, 100 conns",
},
{
name: "stateful_no_rules_1000conns",
stateful: true,
rules: false,
routed: false,
connCount: 1000,
desc: "Pure stateful tracking without rules, 1000 conns",
},
{
name: "stateful_with_rules_100conns",
stateful: true,
rules: true,
routed: false,
connCount: 100,
desc: "Combined stateful + rules (current implementation), 100 conns",
},
{
name: "stateful_with_rules_1000conns",
stateful: true,
rules: true,
routed: false,
connCount: 1000,
desc: "Combined stateful + rules (current implementation), 1000 conns",
},
{
name: "routed_network_100conns",
stateful: true,
rules: false,
routed: true,
connCount: 100,
desc: "Routed network traffic (non-WG), 100 conns",
},
{
name: "routed_network_1000conns",
stateful: true,
rules: false,
routed: true,
connCount: 1000,
desc: "Routed network traffic (non-WG), 1000 conns",
},
}
// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns
func BenchmarkLongLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs for connections
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
// Create established connections
for i := 0; i < sc.connCount; i++ {
// Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
// SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
// ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
}
// Prepare test packets simulating bidirectional traffic
inPackets := make([][]byte, sc.connCount)
outPackets := make([][]byte, sc.connCount)
for i := 0; i < sc.connCount; i++ {
// Server -> Client (inbound)
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
// Client -> Server (outbound)
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
connIdx := i % sc.connCount
// Simulate bidirectional traffic
// First outbound data
manager.processOutgoingHooks(outPackets[connIdx])
// Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
}
})
}
}
// BenchmarkShortLivedConnections tests performance with many short-lived connections
func BenchmarkShortLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
// Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs for connections
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
// Create packet patterns for a complete HTTP-like short connection:
// 1. Initial handshake (SYN, SYN-ACK, ACK)
// 2. HTTP Request (PSH+ACK from client)
// 3. HTTP Response (PSH+ACK from server)
// 4. Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK)
type connPackets struct {
syn []byte
synAck []byte
ack []byte
request []byte
response []byte
finClient []byte
ackServer []byte
finServer []byte
ackClient []byte
}
// Generate all possible connection patterns
patterns := make([]connPackets, sc.connCount)
for i := 0; i < sc.connCount; i++ {
patterns[i] = connPackets{
// Handshake
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
// Data transfer
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
// Connection teardown
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPAck)),
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Each iteration creates a new short-lived connection
connIdx := i % sc.connCount
p := patterns[connIdx]
// Connection establishment
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack)
// Data transfer
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules)
// Connection teardown
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient)
}
})
}
}
// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel
func BenchmarkParallelLongLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs for connections
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
// Create established connections
for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack, manager.incomingRules)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack)
}
// Pre-generate test packets
inPackets := make([][]byte, sc.connCount)
outPackets := make([][]byte, sc.connCount)
for i := 0; i < sc.connCount; i++ {
inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck))
outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
// Each goroutine gets its own counter to distribute load
counter := 0
for pb.Next() {
connIdx := counter % sc.connCount
counter++
// Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx])
manager.dropFilter(inPackets[connIdx], manager.incomingRules)
}
})
})
}
}
// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel
func BenchmarkParallelShortLivedConnections(b *testing.B) {
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
// Configure stateful/stateless mode
if !sc.stateful {
b.Setenv("NB_DISABLE_CONNTRACK", "1")
} else {
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil))
})
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err)
}
// Generate IPs and pre-generate all packet patterns
srcIPs := make([]net.IP, sc.connCount)
dstIPs := make([]net.IP, sc.connCount)
for i := 0; i < sc.connCount; i++ {
if sc.routed {
srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4()
dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4()
} else {
srcIPs[i] = generateRandomIPs(1)[0]
dstIPs[i] = generateRandomIPs(1)[0]
}
}
type connPackets struct {
syn []byte
synAck []byte
ack []byte
request []byte
response []byte
finClient []byte
ackServer []byte
finServer []byte
ackClient []byte
}
patterns := make([]connPackets, sc.connCount)
for i := 0; i < sc.connCount; i++ {
patterns[i] = connPackets{
syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)),
synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)),
ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)),
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)),
finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPAck)),
finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)),
ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)),
}
}
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
counter := 0
for pb.Next() {
connIdx := counter % sc.connCount
counter++
p := patterns[connIdx]
// Full connection lifecycle
manager.processOutgoingHooks(p.syn)
manager.dropFilter(p.synAck, manager.incomingRules)
manager.processOutgoingHooks(p.ack)
manager.processOutgoingHooks(p.request)
manager.dropFilter(p.response, manager.incomingRules)
manager.processOutgoingHooks(p.finClient)
manager.dropFilter(p.ackServer, manager.incomingRules)
manager.dropFilter(p.finServer, manager.incomingRules)
manager.processOutgoingHooks(p.ackClient)
}
})
})
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}

View File

@@ -3,6 +3,7 @@ package uspfilter
import (
"fmt"
"net"
"sync"
"testing"
"time"
@@ -11,6 +12,7 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -67,12 +69,11 @@ func TestManagerAddPeerFiltering(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -102,38 +103,16 @@ func TestManagerDeleteRule(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
comment := "Test rule 2"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
ip = net.ParseIP("192.168.1.1")
proto = fw.ProtocolTCP
port = &fw.Port{Values: []int{80}}
direction = fw.RuleDirectionIN
action = fw.ActionDrop
comment = "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
}
for _, r := range rule {
err = m.DeletePeerRule(r)
if err != nil {
t.Errorf("failed to delete rule: %v", err)
return
}
}
for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
t.Errorf("rule2 is not in the incomingRules")
@@ -185,10 +164,10 @@ func TestAddUDPPacketHook(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager := &Manager{
incomingRules: map[string]RuleSet{},
outgoingRules: map[string]RuleSet{},
}
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -215,18 +194,14 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if tt.expDir != addedRule.direction {
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
@@ -248,12 +223,11 @@ func TestManagerReset(t *testing.T) {
ip := net.ParseIP("192.168.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []int{80}}
direction := fw.RuleDirectionOUT
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -287,11 +261,10 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP
direction := fw.RuleDirectionOUT
action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
if err != nil {
t.Errorf("failed to add filtering: %v", err)
return
@@ -313,7 +286,7 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to set network layer for checksum: %v", err)
return
}
payload := gopacket.Payload([]byte("test"))
payload := gopacket.Payload("test")
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
@@ -325,7 +298,7 @@ func TestNotMatchByIP(t *testing.T) {
return
}
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
if m.dropFilter(buf.Bytes(), m.incomingRules) {
t.Errorf("expected packet to be accepted")
return
}
@@ -348,6 +321,9 @@ func TestRemovePacketHook(t *testing.T) {
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
defer func() {
require.NoError(t, manager.Reset(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
@@ -384,6 +360,88 @@ func TestRemovePacketHook(t *testing.T) {
}
}
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond)
defer func() {
require.NoError(t, manager.Reset(nil))
}()
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
net.ParseIP("100.10.0.100"),
53,
func([]byte) bool {
hookCalled = true
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: net.ParseIP("100.10.0.1"),
DstIP: net.ParseIP("100.10.0.100"),
Protocol: layers.IPProtocolUDP,
}
udp := &layers.UDP{
SrcPort: 51334,
DstPort: 53,
}
err = udp.SetNetworkLayerForChecksum(ipv4)
require.NoError(t, err)
payload := gopacket.Payload("test")
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
require.NoError(t, err)
// Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes())
require.True(t, result)
require.True(t, hookCalled)
// Test non-UDP packet is ignored
ipv4.Protocol = layers.IPProtocolTCP
buf = gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes())
require.False(t, result)
}
func TestUSPFilterCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
@@ -405,12 +463,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ip := net.ParseIP("10.20.0.100")
start := time.Now()
for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else {
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
}
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule")
}
@@ -418,3 +472,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
})
}
}
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
})
require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond)
manager.decoders = sync.Pool{
New: func() any {
d := &decoder{
decoded: []gopacket.LayerType{},
}
d.parser = gopacket.NewDecodingLayerParser(
layers.LayerTypeIPv4,
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
)
d.parser.IgnoreUnsupported = true
return d
},
}
defer func() {
require.NoError(t, manager.Reset(nil))
}()
// Set up packet parameters
srcIP := net.ParseIP("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100")
srcPort := uint16(51334)
dstPort := uint16(53)
// Create outbound packet
outboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolUDP,
}
outboundUDP := &layers.UDP{
SrcPort: layers.UDPPort(srcPort),
DstPort: layers.UDPPort(dstPort),
}
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
require.NoError(t, err)
outboundBuf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err = gopacket.SerializeLayers(outboundBuf, opts,
outboundIPv4,
outboundUDP,
gopacket.Payload("test"),
)
require.NoError(t, err)
// Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes())
require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
// Create valid inbound response packet
inboundIPv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: dstIP, // Original destination is now source
DstIP: srcIP, // Original source is now destination
Protocol: layers.IPProtocolUDP,
}
inboundUDP := &layers.UDP{
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
}
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
require.NoError(t, err)
inboundBuf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(inboundBuf, opts,
inboundIPv4,
inboundUDP,
gopacket.Payload("response"),
)
require.NoError(t, err)
// Test roundtrip response handling over time
checkPoints := []struct {
sleep time.Duration
shouldAllow bool
description string
}{
{
sleep: 0,
shouldAllow: true,
description: "Immediate response should be allowed",
},
{
sleep: 50 * time.Millisecond,
shouldAllow: true,
description: "Response within timeout should be allowed",
},
{
sleep: 100 * time.Millisecond,
shouldAllow: true,
description: "Response at half timeout should be allowed",
},
{
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
sleep: 250 * time.Millisecond,
shouldAllow: false,
description: "Response after timeout should be dropped",
},
}
for _, cp := range checkPoints {
time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules)
require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists
if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses")
}
}
// Test invalid response packets (while connection is expired)
invalidCases := []struct {
name string
modifyFunc func(*layers.IPv4, *layers.UDP)
description string
}{
{
name: "wrong source IP",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
ip.SrcIP = net.ParseIP("100.10.0.101")
},
description: "Response from wrong IP should be dropped",
},
{
name: "wrong destination IP",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
ip.DstIP = net.ParseIP("100.10.0.2")
},
description: "Response to wrong IP should be dropped",
},
{
name: "wrong source port",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
udp.SrcPort = 54
},
description: "Response from wrong port should be dropped",
},
{
name: "wrong destination port",
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
udp.DstPort = 51335
},
description: "Response to wrong port should be dropped",
},
}
// Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases {
t.Run(tc.name, func(t *testing.T) {
testIPv4 := *inboundIPv4
testUDP := *inboundUDP
tc.modifyFunc(&testIPv4, &testUDP)
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
require.NoError(t, err)
testBuf := gopacket.NewSerializeBuffer()
err = gopacket.SerializeLayers(testBuf, opts,
&testIPv4,
&testUDP,
gopacket.Payload("response"),
)
require.NoError(t, err)
// Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules)
require.True(t, drop, tc.description)
})
}
}

View File

@@ -33,8 +33,6 @@ type TunKernelDevice struct {
}
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser()
ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{
ctx: ctx,

View File

@@ -4,8 +4,6 @@ package device
import (
"fmt"
"os"
"runtime"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device"
@@ -32,8 +30,6 @@ type USPDevice struct {
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode")
checkUser()
return &USPDevice{
name: name,
address: address,
@@ -134,12 +130,3 @@ func (t *USPDevice) assignAddr() error {
return link.assignAddr(t.address)
}
func checkUser() {
if runtime.GOOS == "freebsd" {
euid := os.Geteuid()
if euid != 0 {
log.Warn("newTunUSPDevice: on netbird must run as root to be able to assign address to the tun interface with ifconfig")
}
}
}

View File

@@ -203,6 +203,11 @@ func (l *Link) setAddr(ip, netmask string) error {
return fmt.Errorf("set interface addr: %w", err)
}
cmd = exec.Command("ifconfig", l.name, "inet6", "fe80::/64")
if out, err := cmd.CombinedOutput(); err != nil {
log.Debugf("adding address command '%v' failed with output: %s", cmd.String(), out)
}
return nil
}

View File

@@ -15,6 +15,10 @@ func IsEnabled() bool {
func ListenAddr() string {
sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT")
if sPort == "" {
return listenAddr(DefaultSocks5Port)
}
port, err := strconv.Atoi(sPort)
if err != nil {
log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port)

View File

@@ -151,7 +151,7 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.rollBack(newRulePairs)
break
}
if len(rules) > 0 {
if len(rulePair) > 0 {
d.peerRulesPairs[pairID] = rulePair
newRulePairs[pairID] = rulePair
}
@@ -268,13 +268,16 @@ func (d *DefaultManager) protoRuleToFirewallRule(
}
var port *firewall.Port
if r.Port != "" {
if r.PortInfo != nil {
port = convertPortInfo(r.PortInfo)
} else if r.Port != "" {
// old version of management, single port
value, err := strconv.Atoi(r.Port)
if err != nil {
return "", nil, fmt.Errorf("invalid port, skipping firewall rule")
return "", nil, fmt.Errorf("invalid port: %w", err)
}
port = &firewall.Port{
Values: []int{value},
Values: []uint16{uint16(value)},
}
}
@@ -288,6 +291,8 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -308,25 +313,12 @@ func (d *DefaultManager) addInRules(
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionIN, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
return nil, fmt.Errorf("add firewall rule: %w", err)
}
rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
return append(rules, rule...), nil
return rule, nil
}
func (d *DefaultManager) addOutRules(
@@ -337,25 +329,16 @@ func (d *DefaultManager) addOutRules(
ipsetName string,
comment string,
) ([]firewall.Rule, error) {
var rules []firewall.Rule
rule, err := d.firewall.AddPeerFiltering(
ip, protocol, nil, port, firewall.RuleDirectionOUT, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
}
rules = append(rules, rule...)
if shouldSkipInvertedRule(protocol, port) {
return rules, nil
return nil, nil
}
rule, err = d.firewall.AddPeerFiltering(
ip, protocol, port, nil, firewall.RuleDirectionIN, action, ipsetName, comment)
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
if err != nil {
return nil, fmt.Errorf("failed to add firewall rule: %v", err)
return nil, fmt.Errorf("add firewall rule: %w", err)
}
return append(rules, rule...), nil
return rule, nil
}
// getPeerRuleID() returns unique ID for the rule based on its parameters.
@@ -559,14 +542,14 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
if portInfo.GetPort() != 0 {
return &firewall.Port{
Values: []int{int(portInfo.GetPort())},
Values: []uint16{uint16(int(portInfo.GetPort()))},
}
}
if portInfo.GetRange() != nil {
return &firewall.Port{
IsRange: true,
Values: []int{int(portInfo.GetRange().Start), int(portInfo.GetRange().End)},
Values: []uint16{uint16(portInfo.GetRange().Start), uint16(portInfo.GetRange().End)},
}
}

View File

@@ -119,8 +119,8 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 2 {
t.Errorf("rules should contain 2 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return
}
})
@@ -356,8 +356,8 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
acl.ApplyFiltering(networkMap)
if len(acl.peerRulesPairs) != 4 {
t.Errorf("expect 4 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
return
}
}

View File

@@ -61,6 +61,13 @@ type ConfigInput struct {
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
DisableClientRoutes *bool
DisableServerRoutes *bool
DisableDNS *bool
DisableFirewall *bool
BlockLANAccess *bool
}
// Config Configuration type
@@ -78,6 +85,14 @@ type Config struct {
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
// SSHKey is a private SSH key in a PEM format
SSHKey string
@@ -402,7 +417,56 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.DNSRouteInterval = dynamic.DefaultInterval
log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
updated = true
}
if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes {
if *input.DisableClientRoutes {
log.Infof("disabling client routes")
} else {
log.Infof("enabling client routes")
}
config.DisableClientRoutes = *input.DisableClientRoutes
updated = true
}
if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes {
if *input.DisableServerRoutes {
log.Infof("disabling server routes")
} else {
log.Infof("enabling server routes")
}
config.DisableServerRoutes = *input.DisableServerRoutes
updated = true
}
if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS {
if *input.DisableDNS {
log.Infof("disabling DNS configuration")
} else {
log.Infof("enabling DNS configuration")
}
config.DisableDNS = *input.DisableDNS
updated = true
}
if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall {
if *input.DisableFirewall {
log.Infof("disabling firewall configuration")
} else {
log.Infof("enabling firewall configuration")
}
config.DisableFirewall = *input.DisableFirewall
updated = true
}
if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
if *input.BlockLANAccess {
log.Infof("blocking LAN access")
} else {
log.Infof("allowing LAN access")
}
config.BlockLANAccess = *input.BlockLANAccess
updated = true
}
if input.ClientCertKeyPath != "" {

View File

@@ -59,13 +59,8 @@ func NewConnectClient(
}
// Run with main logic.
func (c *ConnectClient) Run() error {
return c.run(MobileDependency{}, nil, nil)
}
// RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
return c.run(MobileDependency{}, probes, runningChan)
func (c *ConnectClient) Run(runningChan chan error) error {
return c.run(MobileDependency{}, runningChan)
}
// RunOnAndroid with main logic on mobile system
@@ -84,7 +79,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
}
return c.run(mobileDependency, nil, nil)
return c.run(mobileDependency, nil)
}
func (c *ConnectClient) RunOniOS(
@@ -102,10 +97,10 @@ func (c *ConnectClient) RunOniOS(
DnsManager: dnsManager,
StateFilePath: stateFilePath,
}
return c.run(mobileDependency, nil, nil)
return c.run(mobileDependency, nil)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
defer func() {
if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
@@ -183,7 +178,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
}()
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey)
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
if err != nil {
log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
@@ -261,7 +256,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
checks := loginResp.GetChecks()
c.engineMutex.Lock()
c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks)
c.engine = NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
c.engine.SetNetworkMapPersistence(c.persistNetworkMap)
c.engineMutex.Unlock()
@@ -382,8 +377,7 @@ func (c *ConnectClient) isContextCancelled() bool {
// SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored
// network map will be cleared. This functionality is primarily used for debugging
// and should not be enabled during normal operation.
// network map will be cleared.
func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) {
c.engineMutex.Lock()
c.persistNetworkMap = enabled
@@ -416,6 +410,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
RosenpassPermissive: config.RosenpassPermissive,
ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed),
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes,
DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess,
}
if config.PreSharedKey != "" {
@@ -457,7 +458,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
}
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey()
if err != nil {
@@ -465,6 +466,15 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte)
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
if err != nil {
return nil, err

View File

@@ -0,0 +1,18 @@
//go:build !android
package dns
import (
"github.com/netbirdio/netbird/client/configs"
"os"
"path/filepath"
)
var fileUncleanShutdownResolvConfLocation string
func init() {
fileUncleanShutdownResolvConfLocation = os.Getenv("NB_UNCLEAN_SHUTDOWN_RESOLV_FILE")
if fileUncleanShutdownResolvConfLocation == "" {
fileUncleanShutdownResolvConfLocation = filepath.Join(configs.StateDir, "resolv.conf")
}
}

View File

@@ -1,5 +0,0 @@
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
)

View File

@@ -1,7 +0,0 @@
//go:build !android
package dns
const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
)

View File

@@ -0,0 +1,238 @@
package dns
import (
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
)
const (
PriorityDNSRoute = 100
PriorityMatchDomain = 50
PriorityDefault = 0
)
type SubdomainMatcher interface {
dns.Handler
MatchSubdomains() bool
}
type HandlerEntry struct {
Handler dns.Handler
Priority int
Pattern string
OrigPattern string
IsWildcard bool
StopHandler handlerWithStop
MatchSubdomains bool
}
// HandlerChain represents a prioritized chain of DNS handlers
type HandlerChain struct {
mu sync.RWMutex
handlers []HandlerEntry
}
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct {
dns.ResponseWriter
origPattern string
shouldContinue bool
}
func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error {
// Check if this is a continue signal (NXDOMAIN with Zero bit set)
if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero {
w.shouldContinue = true
return nil
}
return w.ResponseWriter.WriteMsg(m)
}
func NewHandlerChain() *HandlerChain {
return &HandlerChain{
handlers: make([]HandlerEntry, 0),
}
}
// GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string {
return w.origPattern
}
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
// First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
// Check if handler implements SubdomainMatcher interface
matchSubdomains := false
if matcher, ok := handler.(SubdomainMatcher); ok {
matchSubdomains = matcher.MatchSubdomains()
}
log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
pattern, origPattern, isWildcard, matchSubdomains, priority)
entry := HandlerEntry{
Handler: handler,
Priority: priority,
Pattern: pattern,
OrigPattern: origPattern,
IsWildcard: isWildcard,
StopHandler: stopHandler,
MatchSubdomains: matchSubdomains,
}
pos := c.findHandlerPosition(entry)
c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...)
}
// findHandlerPosition determines where to insert a new handler based on priority and specificity
func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
for i, h := range c.handlers {
// prio first
if h.Priority < newEntry.Priority {
return i
}
// domain specificity next
if h.Priority == newEntry.Priority {
newDots := strings.Count(newEntry.Pattern, ".")
existingDots := strings.Count(h.Pattern, ".")
if newDots > existingDots {
return i
}
}
}
// add at end
return len(c.handlers)
}
// RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
c.mu.Lock()
defer c.mu.Unlock()
pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return
}
}
}
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock()
handlers := slices.Clone(c.handlers)
c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers))
for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d",
h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority)
}
}
// Try handlers in priority order
for _, entry := range handlers {
var matched bool
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
}
}
if !matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard)
continue
}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains)
chainWriter := &ResponseWriterChain{
ResponseWriter: w,
origPattern: entry.OrigPattern,
}
entry.Handler.ServeDNS(chainWriter, r)
// If handler wants to continue, try next handler
if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler")
continue
}
return
}
// No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname)
resp := &dns.Msg{}
resp.SetRcode(r, dns.RcodeNameError)
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}

View File

@@ -0,0 +1,832 @@
package dns_test
import (
"net"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
)
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create mock handlers for different priorities
defaultHandler := &nbdns.MockHandler{}
matchDomainHandler := &nbdns.MockHandler{}
dnsRouteHandler := &nbdns.MockHandler{}
// Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe()
// Execute
chain.ServeDNS(w, r)
// Verify all expectations were met
dnsRouteHandler.AssertExpectations(t)
matchDomainHandler.AssertExpectations(t)
defaultHandler.AssertExpectations(t)
}
// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct {
name string
handlerDomain string
queryDomain string
isWildcard bool
matchSubdomains bool
shouldMatch bool
}{
{
name: "exact match",
handlerDomain: "example.com.",
queryDomain: "example.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard and MatchSubdomains true",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
matchSubdomains: true,
shouldMatch: true,
},
{
name: "subdomain with non-wildcard and MatchSubdomains false",
handlerDomain: "example.com.",
queryDomain: "sub.example.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "wildcard match",
handlerDomain: "*.example.com.",
queryDomain: "sub.example.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "wildcard no match on apex",
handlerDomain: "*.example.com.",
queryDomain: "example.com.",
isWildcard: true,
matchSubdomains: false,
shouldMatch: false,
},
{
name: "root zone match",
handlerDomain: ".",
queryDomain: "anything.com.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: true,
},
{
name: "no match different domain",
handlerDomain: "example.com.",
queryDomain: "example.org.",
isWildcard: false,
matchSubdomains: false,
shouldMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handler dns.Handler
if tt.matchSubdomains {
mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
handler = mockSubHandler
if tt.shouldMatch {
mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
}
} else {
mockHandler := &nbdns.MockHandler{}
handler = mockHandler
if tt.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once()
}
}
pattern := tt.handlerDomain
if tt.isWildcard {
pattern = "*." + tt.handlerDomain[2:]
}
chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil)
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
if h, ok := handler.(*nbdns.MockHandler); ok {
h.AssertExpectations(t)
} else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok {
h.AssertExpectations(t)
}
})
}
}
// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns
func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct {
name string
handlers []struct {
pattern string
priority int
}
queryDomain string
expectedCalls int
expectedHandler int // index of the handler that should be called
}{
{
name: "wildcard and exact same priority - exact should win",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDefault},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // exact match handler should be called
},
{
name: "higher priority wildcard over lower priority exact",
handlers: []struct {
pattern string
priority int
}{
{pattern: "example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority wildcard handler should be called
},
{
name: "multiple wildcards different priorities",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority handler should be called
},
{
name: "subdomain with mix of patterns",
handlers: []struct {
pattern string
priority int
}{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "sub.test.example.com.",
expectedCalls: 1,
expectedHandler: 2, // highest priority matching handler should be called
},
{
name: "root zone with specific domain",
handlers: []struct {
pattern string
priority int
}{
{pattern: ".", priority: nbdns.PriorityDefault},
{pattern: "example.com.", priority: nbdns.PriorityDNSRoute},
},
queryDomain: "example.com.",
expectedCalls: 1,
expectedHandler: 1, // higher priority specific domain should win over root
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
var handlers []*nbdns.MockHandler
// Setup handlers and expectations
for i := range tt.handlers {
handler := &nbdns.MockHandler{}
handlers = append(handlers, handler)
// Set expectation based on whether this handler should be called
if i == tt.expectedHandler {
handler.On("ServeDNS", mock.Anything, mock.Anything).Once()
} else {
handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe()
}
chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil)
}
// Create and execute request
r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
})
}
}
// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality
func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
chain := nbdns.NewHandlerChain()
// Create handlers
handler1 := &nbdns.MockHandler{}
handler2 := &nbdns.MockHandler{}
handler3 := &nbdns.MockHandler{}
// Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil)
// Create test request
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
// Setup mock responses to simulate chain continuation
handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// First handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true // Signal to continue
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Second handler signals continue
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
resp.MsgHdr.Zero = true
assert.NoError(t, w.WriteMsg(resp))
}).Once()
handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
// Last handler responds normally
w := args.Get(0).(*nbdns.ResponseWriterChain)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
// Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
chain.ServeDNS(w, r)
// Verify all handlers were called in order
handler1.AssertExpectations(t)
handler2.AssertExpectations(t)
handler3.AssertExpectations(t)
}
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct {
name string
ops []struct {
action string // "add" or "remove"
pattern string
priority int
}
query string
expectedCalls map[int]bool // map[priority]shouldBeCalled
}{
{
name: "remove high priority keeps lower priority handler",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true,
},
},
{
name: "remove lower priority keeps high priority handler",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false,
},
},
{
name: "remove all handlers in order",
ops: []struct {
action string
pattern string
priority int
}{
{"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain},
{"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain},
},
query: "example.com.",
expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false,
nbdns.PriorityDefault: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[int]*nbdns.MockHandler)
// Execute operations
for _, op := range tt.ops {
if op.action == "add" {
handler := &nbdns.MockHandler{}
handlers[op.priority] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
}
// Create test request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup expectations
for priority, handler := range handlers {
if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall {
handler.On("ServeDNS", mock.Anything, r).Once()
} else {
handler.On("ServeDNS", mock.Anything, r).Maybe()
}
}
// Execute request
chain.ServeDNS(w, r)
// Verify expectations
for _, handler := range handlers {
handler.AssertExpectations(t)
}
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
})
}
}
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain()
testDomain := "example.com."
testQuery := "test.example.com."
// Create handlers with MatchSubdomains enabled
routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true}
// Create test request that will be reused
r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA)
// Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil)
// Test 1: Initial state with all three handlers
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
routeHandler.AssertExpectations(t)
// Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
matchHandler.AssertExpectations(t)
// Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r)
defaultHandler.AssertExpectations(t)
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain))
}
func TestHandlerChain_CaseSensitivity(t *testing.T) {
tests := []struct {
name string
scenario string
addHandlers []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}
query string
expectedCalls int
}{
{
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
},
query: "sub.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "ExAmPlE.cOm.",
expectedCalls: 1,
},
{
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, true, true},
},
query: "SUB.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{".", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
expectedCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called
// Add handlers according to test case
for _, h := range tt.addHandlers {
var handler dns.Handler
pattern := h.pattern // capture pattern for closure
if h.subdomains {
subHandler := &nbdns.MockSubdomainHandler{
Subdomains: true,
}
if h.shouldMatch {
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = subHandler
} else {
mockHandler := &nbdns.MockHandler{}
if h.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = mockHandler
}
chain.AddHandler(pattern, handler, h.priority, nil)
}
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
wasCalled := handlerCalls[h.pattern]
assert.Equal(t, h.shouldMatch, wasCalled,
"Handler for pattern %q was %s when it should%s have been",
h.pattern,
map[bool]string{true: "called", false: "not called"}[wasCalled],
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
}
// Verify total number of calls
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
"Wrong number of total handler calls")
})
}
}
func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
tests := []struct {
name string
scenario string
ops []struct {
action string
pattern string
priority int
subdomain bool
}
query string
expectedMatch string
}{
{
name: "more specific domain matches first",
scenario: "sub.example.com should match before example.com",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "more specific domain matches first, both match subdomains",
scenario: "sub.example.com should match before example.com",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "maintain specificity order after removal",
scenario: "after removing most specific, should fall back to less specific",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "test.sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "priority overrides specificity",
scenario: "less specific domain with higher priority should match first",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true},
},
query: "sub.example.com.",
expectedMatch: "example.com.",
},
{
name: "equal priority respects specificity",
scenario: "with equal priority, more specific domain should match",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "example.com.", nbdns.PriorityMatchDomain, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
{
name: "specific matches before wildcard",
scenario: "specific domain should match before wildcard at same priority",
ops: []struct {
action string
pattern string
priority int
subdomain bool
}{
{"add", "*.example.com.", nbdns.PriorityDNSRoute, false},
{"add", "sub.example.com.", nbdns.PriorityDNSRoute, false},
},
query: "sub.example.com.",
expectedMatch: "sub.example.com.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlers := make(map[string]*nbdns.MockSubdomainHandler)
for _, op := range tt.ops {
if op.action == "add" {
handler := &nbdns.MockSubdomainHandler{Subdomains: op.subdomain}
handlers[op.pattern] = handler
chain.AddHandler(op.pattern, handler, op.priority, nil)
} else {
chain.RemoveHandler(op.pattern, op.priority)
}
}
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
// Setup handler expectations
for pattern, handler := range handlers {
if pattern == tt.expectedMatch {
handler.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) {
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetReply(r)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
}
chain.ServeDNS(w, r)
for pattern, handler := range handlers {
if pattern == tt.expectedMatch {
handler.AssertNumberOfCalls(t, "ServeDNS", 1)
} else {
handler.AssertNumberOfCalls(t, "ServeDNS", 0)
}
}
})
}
}

View File

@@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
return config
}
type noopHostConfigurator struct{}
func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
return nil
}
func (n noopHostConfigurator) restoreHostDNS() error {
return nil
}
func (n noopHostConfigurator) supportCustomPort() bool {
return true
}

View File

@@ -28,6 +28,7 @@ const (
arraySymbol = "* "
digitSymbol = "# "
scutilPath = "/usr/sbin/scutil"
dscacheutilPath = "/usr/bin/dscacheutil"
searchSuffix = "Search"
matchSuffix = "Match"
localSuffix = "Local"
@@ -106,6 +107,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
return fmt.Errorf("add search domains: %w", err)
}
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}
@@ -123,6 +128,10 @@ func (s *systemConfigurator) restoreHostDNS() error {
}
}
if err := s.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil
}
@@ -316,6 +325,21 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil
}
func (s *systemConfigurator) flushDNSCache() error {
cmd := exec.Command(dscacheutilPath, "-flushcache")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("flush DNS cache: %w, output: %s", err, out)
}
cmd = exec.Command("killall", "-HUP", "mDNSResponder")
if out, err := cmd.CombinedOutput(); err != nil {
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
}
log.Info("flushed DNS cache")
return nil
}
func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err)

View File

@@ -48,11 +48,17 @@ type restoreHostManager interface {
func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, err
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s", osManager)
return newHostManagerFromType(wgInterface, osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
return nil, fmt.Errorf("create host manager: %w", err)
}
return mgr, nil
}
func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {

View File

@@ -17,12 +17,24 @@ type localResolver struct {
records sync.Map
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Tracef("received question: %#v", r.Question[0])
if len(r.Question) > 0 {
log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true

View File

@@ -3,14 +3,30 @@ package dns
import (
"fmt"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
// MockServer is the mock instance of a dns server
type MockServer struct {
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
InitializeFunc func() error
StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler, int)
DeregisterHandlerFunc func([]string, int)
}
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
if m.RegisterHandlerFunc != nil {
m.RegisterHandlerFunc(domains, handler, priority)
}
}
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
if m.DeregisterHandlerFunc != nil {
m.DeregisterHandlerFunc(domains, priority)
}
}
// Initialize mock implementation of Initialize from Server interface

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/netip"
"os/exec"
"strings"
log "github.com/sirupsen/logrus"
@@ -15,23 +16,64 @@ import (
const resolvconfCommand = "resolvconf"
// resolvconfType represents the type of resolvconf implementation
type resolvconfType int
func (r resolvconfType) String() string {
switch r {
case typeOpenresolv:
return "openresolv"
case typeResolvconf:
return "resolvconf"
default:
return "unknown"
}
}
const (
typeOpenresolv resolvconfType = iota
typeResolvconf
)
type resolvconf struct {
ifaceName string
implType resolvconfType
originalSearchDomains []string
originalNameServers []string
othersConfigs []string
}
// supported "openresolv" only
func detectResolvconfType() (resolvconfType, error) {
cmd := exec.Command(resolvconfCommand, "--version")
out, err := cmd.Output()
if err != nil {
return typeOpenresolv, fmt.Errorf("failed to determine resolvconf type: %w", err)
}
if strings.Contains(string(out), "openresolv") {
return typeOpenresolv, nil
}
return typeResolvconf, nil
}
func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
resolvConfEntries, err := parseDefaultResolvConf()
if err != nil {
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
}
implType, err := detectResolvconfType()
if err != nil {
log.Warnf("failed to detect resolvconf type, defaulting to openresolv: %v", err)
implType = typeOpenresolv
} else {
log.Infof("detected resolvconf type: %v", implType)
}
return &resolvconf{
ifaceName: wgInterface,
implType: implType,
originalSearchDomains: resolvConfEntries.searchDomains,
originalNameServers: resolvConfEntries.nameServers,
othersConfigs: resolvConfEntries.others,
@@ -80,8 +122,15 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *stateman
}
func (r *resolvconf) restoreHostDNS() error {
// openresolv only, debian resolvconf doesn't support "-f"
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
var cmd *exec.Cmd
switch r.implType {
case typeOpenresolv:
cmd = exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
case typeResolvconf:
cmd = exec.Command(resolvconfCommand, "-d", r.ifaceName)
}
_, err := cmd.Output()
if err != nil {
return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
@@ -91,10 +140,21 @@ func (r *resolvconf) restoreHostDNS() error {
}
func (r *resolvconf) applyConfig(content bytes.Buffer) error {
// openresolv only, debian resolvconf doesn't support "-x"
cmd := exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
var cmd *exec.Cmd
switch r.implType {
case typeOpenresolv:
// OpenResolv supports exclusive mode with -x
cmd = exec.Command(resolvconfCommand, "-x", "-a", r.ifaceName)
case typeResolvconf:
cmd = exec.Command(resolvconfCommand, "-a", r.ifaceName)
default:
return fmt.Errorf("unsupported resolvconf type: %v", r.implType)
}
cmd.Stdin = &content
_, err := cmd.Output()
out, err := cmd.Output()
log.Tracef("resolvconf output: %s", out)
if err != nil {
return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -30,6 +31,8 @@ type IosDnsManager interface {
// Server is a dns server interface
type Server interface {
RegisterHandler(domains []string, handler dns.Handler, priority int)
DeregisterHandler(domains []string, priority int)
Initialize() error
Stop()
DnsIP() string
@@ -45,15 +48,18 @@ type registeredHandlerMap map[string]handlerWithStop
type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
disableSys bool
mux sync.Mutex
service service
dnsMuxMap registeredHandlerMap
handlerPriorities map[string]int
localResolver *localResolver
wgInterface WGIface
hostManager hostManager
updateSerial uint64
previousConfigHash uint64
currentConfig HostDNSConfig
handlerChain *HandlerChain
// permanent related properties
permanent bool
@@ -74,12 +80,20 @@ type handlerWithStop interface {
}
type muxUpdate struct {
domain string
handler handlerWithStop
domain string
handler handlerWithStop
priority int
}
// NewDefaultServer returns a new dns server
func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
func NewDefaultServer(
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort
if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -96,7 +110,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st
dnsService = newServiceViaListener(wgInterface, addrPort)
}
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil
}
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -107,9 +121,10 @@ func NewDefaultServerPermanentUpstream(
config nbdns.Config,
listener listener.NetworkChangeListener,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true
ds.addHostRootZone()
@@ -126,19 +141,30 @@ func NewDefaultServerIos(
wgInterface WGIface,
iosDnsManager IosDnsManager,
statusRecorder *peer.Status,
disableSys bool,
) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys)
ds.iosDnsManager = iosDnsManager
return ds
}
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
func newDefaultServer(
ctx context.Context,
wgInterface WGIface,
dnsService service,
statusRecorder *peer.Status,
stateManager *statemanager.Manager,
disableSys bool,
) *DefaultServer {
ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{
ctx: ctx,
ctxCancel: stop,
service: dnsService,
dnsMuxMap: make(registeredHandlerMap),
ctx: ctx,
ctxCancel: stop,
disableSys: disableSys,
service: dnsService,
handlerChain: NewHandlerChain(),
dnsMuxMap: make(registeredHandlerMap),
handlerPriorities: make(map[string]int),
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
@@ -151,6 +177,51 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
return defaultServer
}
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.registerHandler(domains, handler, priority)
}
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
log.Debugf("registering handler %s with priority %d", handler, priority)
for _, domain := range domains {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.handlerChain.AddHandler(domain, handler, priority, nil)
s.handlerPriorities[domain] = priority
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
}
}
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
s.mux.Lock()
defer s.mux.Unlock()
s.deregisterHandler(domains, priority)
}
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
log.Debugf("deregistering handler %v with priority %d", domains, priority)
for _, domain := range domains {
s.handlerChain.RemoveHandler(domain, priority)
// Only deregister from service if no handlers remain
if !s.handlerChain.HasHandlers(domain) {
if domain == "" {
log.Warn("skipping empty domain")
continue
}
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
}
}
}
// Initialize instantiate host manager and the dns service
func (s *DefaultServer) Initialize() (err error) {
s.mux.Lock()
@@ -168,6 +239,16 @@ func (s *DefaultServer) Initialize() (err error) {
}
s.stateManager.RegisterState(&ShutdownState{})
// use noop host manager if requested or running in netstack mode.
// Netstack mode currently doesn't have a way to receive DNS requests.
// TODO: Use listener on localhost in netstack mode when running as root.
if s.disableSys || netstack.IsEnabled() {
log.Info("system DNS is disabled, not setting up host manager")
s.hostManager = &noopHostConfigurator{}
return nil
}
s.hostManager, err = s.initialize()
if err != nil {
return fmt.Errorf("initialize: %w", err)
@@ -216,47 +297,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) {
// UpdateDNSServer processes an update received from the management service
func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error {
select {
case <-s.ctx.Done():
if s.ctx.Err() != nil {
log.Infof("not updating DNS server as context is closed")
return s.ctx.Err()
default:
if serial < s.updateSerial {
return fmt.Errorf("not applying dns update, error: "+
"network update is %d behind the last applied update", s.updateSerial-serial)
}
s.mux.Lock()
defer s.mux.Unlock()
}
if s.hostManager == nil {
return fmt.Errorf("dns service is not initialized yet")
}
if serial < s.updateSerial {
return fmt.Errorf("not applying dns update, error: "+
"network update is %d behind the last applied update", s.updateSerial-serial)
}
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
}
s.mux.Lock()
defer s.mux.Unlock()
if s.previousConfigHash == hash {
log.Debugf("not applying the dns configuration update as there is nothing new")
s.updateSerial = serial
return nil
}
if s.hostManager == nil {
return fmt.Errorf("dns service is not initialized yet")
}
if err := s.applyConfiguration(update); err != nil {
return fmt.Errorf("apply configuration: %w", err)
}
hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
if err != nil {
log.Errorf("unable to hash the dns configuration update, got error: %s", err)
}
if s.previousConfigHash == hash {
log.Debugf("not applying the dns configuration update as there is nothing new")
s.updateSerial = serial
s.previousConfigHash = hash
return nil
}
if err := s.applyConfiguration(update); err != nil {
return fmt.Errorf("apply configuration: %w", err)
}
s.updateSerial = serial
s.previousConfigHash = hash
return nil
}
func (s *DefaultServer) SearchDomains() []string {
@@ -343,14 +424,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
localRecords := make(map[string]nbdns.SimpleRecord, 0)
for _, customZone := range customZones {
if len(customZone.Records) == 0 {
return nil, nil, fmt.Errorf("received an empty list of records")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: customZone.Domain,
handler: s.localResolver,
domain: customZone.Domain,
handler: s.localResolver,
priority: PriorityMatchDomain,
})
for _, record := range customZone.Records {
@@ -412,8 +493,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
if nsGroup.Primary {
muxUpdates = append(muxUpdates, muxUpdate{
domain: nbdns.RootZone,
handler: handler,
domain: nbdns.RootZone,
handler: handler,
priority: PriorityDefault,
})
continue
}
@@ -429,8 +511,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
return nil, fmt.Errorf("received a nameserver group with an empty domain element")
}
muxUpdates = append(muxUpdates, muxUpdate{
domain: domain,
handler: handler,
domain: domain,
handler: handler,
priority: PriorityMatchDomain,
})
}
}
@@ -440,12 +523,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
muxUpdateMap := make(registeredHandlerMap)
handlersByPriority := make(map[string]int)
var isContainRootUpdate bool
// First register new handlers
for _, update := range muxUpdates {
s.service.RegisterMux(update.domain, update.handler)
s.registerHandler([]string{update.domain}, update.handler, update.priority)
muxUpdateMap[update.domain] = update.handler
handlersByPriority[update.domain] = update.priority
if existingHandler, ok := s.dnsMuxMap[update.domain]; ok {
existingHandler.stop()
}
@@ -455,6 +542,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
}
}
// Then deregister old handlers not in the update
for key, existingHandler := range s.dnsMuxMap {
_, found := muxUpdateMap[key]
if !found {
@@ -463,12 +551,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) {
existingHandler.stop()
} else {
existingHandler.stop()
s.service.DeregisterMux(key)
// Deregister with the priority that was used to register
if oldPriority, ok := s.handlerPriorities[key]; ok {
s.deregisterHandler([]string{key}, oldPriority)
}
}
}
}
s.dnsMuxMap = muxUpdateMap
s.handlerPriorities = handlersByPriority
}
func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
@@ -517,13 +609,13 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
removeIndex[nbdns.RootZone] = -1
s.currentConfig.RouteAll = false
s.service.DeregisterMux(nbdns.RootZone)
s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault)
}
for i, item := range s.currentConfig.Domains {
if _, found := removeIndex[item.Domain]; found {
s.currentConfig.Domains[i].Disabled = true
s.service.DeregisterMux(item.Domain)
s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain)
removeIndex[item.Domain] = i
}
}
@@ -554,7 +646,7 @@ func (s *DefaultServer) upstreamCallbacks(
continue
}
s.currentConfig.Domains[i].Disabled = false
s.service.RegisterMux(domain, handler)
s.registerHandler([]string{domain}, handler, PriorityMatchDomain)
}
l := log.WithField("nameservers", nsGroup.NameServers)
@@ -562,10 +654,13 @@ func (s *DefaultServer) upstreamCallbacks(
if nsGroup.Primary {
s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler)
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
}
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
if s.hostManager != nil {
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
}
}
s.updateNSState(nsGroup, nil, true)
@@ -593,7 +688,8 @@ func (s *DefaultServer) addHostRootZone() {
}
handler.deactivate = func(error) {}
handler.reactivate = func() {}
s.service.RegisterMux(nbdns.RootZone, handler)
s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault)
}
func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) {

View File

@@ -11,7 +11,9 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -292,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
if err != nil {
t.Fatal(err)
}
@@ -401,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return
}
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false)
if err != nil {
t.Errorf("create DNS server: %v", err)
return
@@ -496,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false)
if err != nil {
t.Fatalf("%v", err)
}
@@ -512,7 +514,7 @@ func TestDNSServerStartStop(t *testing.T) {
t.Error(err)
}
dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver)
dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1)
resolver := &net.Resolver{
PreferGo: true,
@@ -560,7 +562,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
localResolver: &localResolver{
registeredMap: make(registrationMap),
},
hostManager: hostManager,
handlerChain: NewHandlerChain(),
handlerPriorities: make(map[string]int),
hostManager: hostManager,
currentConfig: HostDNSConfig{
Domains: []DomainConfig{
{false, "domain0", false},
@@ -629,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) {
var dnsList []string
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -653,7 +657,7 @@ func TestDNSPermanent_updateUpstream(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -745,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
}
defer wgIFace.Close()
dnsConfig := nbdns.Config{}
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{})
dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false)
err = dnsServer.Initialize()
if err != nil {
t.Errorf("failed to initialize DNS server: %v", err)
@@ -872,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver {
},
}
}
// MockHandler implements dns.Handler interface for testing
type MockHandler struct {
mock.Mock
}
func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.Called(w, r)
}
type MockSubdomainHandler struct {
MockHandler
Subdomains bool
}
func (m *MockSubdomainHandler) MatchSubdomains() bool {
return m.Subdomains
}
func TestHandlerChain_DomainPriorities(t *testing.T) {
chain := NewHandlerChain()
dnsRouteHandler := &MockHandler{}
upstreamHandler := &MockSubdomainHandler{
Subdomains: true,
}
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil)
testCases := []struct {
name string
query string
expectedHandler dns.Handler
}{
{
name: "exact domain with dns route handler",
query: "example.com.",
expectedHandler: dnsRouteHandler,
},
{
name: "subdomain should use upstream handler",
query: "sub.example.com.",
expectedHandler: upstreamHandler,
},
{
name: "deep subdomain should use upstream handler",
query: "deep.sub.example.com.",
expectedHandler: upstreamHandler,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
r := new(dns.Msg)
r.SetQuestion(tc.query, dns.TypeA)
w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once()
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
mh.On("ServeDNS", mock.Anything, r).Once()
}
chain.ServeDNS(w, r)
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.AssertExpectations(t)
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
mh.AssertExpectations(t)
}
// Reset mocks
if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.ExpectedCalls = nil
mh.Calls = nil
} else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok {
mh.ExpectedCalls = nil
mh.Calls = nil
}
})
}
}

View File

@@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() {
}
func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) {
log.Debugf("registering dns handler for pattern: %s", pattern)
s.dnsMux.Handle(pattern, handler)
}

View File

@@ -66,6 +66,15 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) *
}
}
// String returns a string representation of the upstream resolver
func (u *upstreamResolverBase) String() string {
return fmt.Sprintf("upstream %v", u.upstreamServers)
}
func (u *upstreamResolverBase) MatchSubdomains() bool {
return true
}
func (u *upstreamResolverBase) stop() {
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
u.cancel()

View File

@@ -0,0 +1,157 @@
package dnsfwd
import (
"context"
"errors"
"net"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
const errResolveFailed = "failed to resolve query for domain=%s: %v"
type DNSForwarder struct {
listenAddress string
ttl uint32
domains []string
dnsServer *dns.Server
mux *dns.ServeMux
}
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{
listenAddress: listenAddress,
ttl: ttl,
}
}
func (f *DNSForwarder) Listen(domains []string) error {
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
mux := dns.NewServeMux()
dnsServer := &dns.Server{
Addr: f.listenAddress,
Net: "udp",
Handler: mux,
}
f.dnsServer = dnsServer
f.mux = mux
f.UpdateDomains(domains)
return dnsServer.ListenAndServe()
}
func (f *DNSForwarder) UpdateDomains(domains []string) {
log.Debugf("Updating domains from %v to %v", f.domains, domains)
for _, d := range f.domains {
f.mux.HandleRemove(d)
}
newDomains := filterDomains(domains)
for _, d := range newDomains {
f.mux.HandleFunc(d, f.handleDNSQuery)
}
f.domains = newDomains
}
func (f *DNSForwarder) Close(ctx context.Context) error {
if f.dnsServer == nil {
return nil
}
return f.dnsServer.ShutdownContext(ctx)
}
func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
if len(query.Question) == 0 {
return
}
log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v",
query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass)
question := query.Question[0]
domain := question.Name
resp := query.SetReply(query)
ips, err := net.LookupIP(domain)
if err != nil {
var dnsErr *net.DNSError
switch {
case errors.As(err, &dnsErr):
resp.Rcode = dns.RcodeServerFailure
if dnsErr.IsNotFound {
// Pass through NXDOMAIN
resp.Rcode = dns.RcodeNameError
}
if dnsErr.Server != "" {
log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err)
} else {
log.Warnf(errResolveFailed, domain, err)
}
default:
resp.Rcode = dns.RcodeServerFailure
log.Warnf(errResolveFailed, domain, err)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write failure DNS response: %v", err)
}
return
}
for _, ip := range ips {
var respRecord dns.RR
if ip.To4() == nil {
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
rr := dns.AAAA{
AAAA: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
} else {
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
rr := dns.A{
A: ip,
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: f.ttl,
},
}
respRecord = &rr
}
resp.Answer = append(resp.Answer, respRecord)
}
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
}
// filterDomains returns a list of normalized domains
func filterDomains(domains []string) []string {
newDomains := make([]string, 0, len(domains))
for _, d := range domains {
if d == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, nbdns.NormalizeZone(d))
}
return newDomains
}

View File

@@ -0,0 +1,111 @@
package dnsfwd
import (
"context"
"fmt"
"net"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
const (
// ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also
ListenPort = 5353
dnsTTL = 60 //seconds
)
type Manager struct {
firewall firewall.Manager
fwRules []firewall.Rule
dnsForwarder *DNSForwarder
}
func NewManager(fw firewall.Manager) *Manager {
return &Manager{
firewall: fw,
}
}
func (m *Manager) Start(domains []string) error {
log.Infof("starting DNS forwarder")
if m.dnsForwarder != nil {
return nil
}
if err := m.allowDNSFirewall(); err != nil {
return err
}
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
go func() {
if err := m.dnsForwarder.Listen(domains); err != nil {
// todo handle close error if it is exists
log.Errorf("failed to start DNS forwarder, err: %v", err)
}
}()
return nil
}
func (m *Manager) UpdateDomains(domains []string) {
if m.dnsForwarder == nil {
return
}
m.dnsForwarder.UpdateDomains(domains)
}
func (m *Manager) Stop(ctx context.Context) error {
if m.dnsForwarder == nil {
return nil
}
var mErr *multierror.Error
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
if err := m.dnsForwarder.Close(ctx); err != nil {
mErr = multierror.Append(mErr, err)
}
m.dnsForwarder = nil
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) allowDNSFirewall() error {
dport := &firewall.Port{
IsRange: false,
Values: []uint16{ListenPort},
}
if h.firewall == nil {
return nil
}
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
}
h.fwRules = dnsRules
return nil
}
func (h *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range h.fwRules {
if err := h.firewall.DeletePeerRule(rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err))
}
}
h.fwRules = nil
return nberrors.FormatErrorOrNil(mErr)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"maps"
"math/rand"
"net"
"net/netip"
@@ -17,23 +16,28 @@ import (
"sync/atomic"
"time"
"github.com/hashicorp/go-multierror"
"github.com/pion/ice/v3"
"github.com/pion/stun/v2"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/protobuf/proto"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
@@ -107,6 +111,13 @@ type EngineConfig struct {
ServerSSHAllowed bool
DNSRouteInterval time.Duration
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
}
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@@ -117,7 +128,7 @@ type Engine struct {
// mgmClient is a Management Service client
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
peerStore *peerstore.Store
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
@@ -137,10 +148,6 @@ type Engine struct {
TURNs []*stun.URI
stunTurn atomic.Value
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
clientRoutesMu sync.RWMutex
clientCtx context.Context
clientCancel context.CancelFunc
@@ -161,14 +168,13 @@ type Engine struct {
statusRecorder *peer.Status
firewall manager.Manager
routeManager routemanager.Manager
acl acl.Manager
firewall manager.Manager
routeManager routemanager.Manager
acl acl.Manager
dnsForwardMgr *dnsfwd.Manager
dnsServer dns.Server
probes *ProbeHolder
// checks are the client-applied posture checks that need to be evaluated on the client
checks []*mgmProto.Checks
@@ -188,7 +194,7 @@ type Peer struct {
WgAllowedIps string
}
// NewEngine creates a new Connection Engine
// NewEngine creates a new Connection Engine with probes attached
func NewEngine(
clientCtx context.Context,
clientCancel context.CancelFunc,
@@ -199,33 +205,6 @@ func NewEngine(
mobileDep MobileDependency,
statusRecorder *peer.Status,
checks []*mgmProto.Checks,
) *Engine {
return NewEngineWithProbes(
clientCtx,
clientCancel,
signalClient,
mgmClient,
relayManager,
config,
mobileDep,
statusRecorder,
nil,
checks,
)
}
// NewEngineWithProbes creates a new Connection Engine with probes attached
func NewEngineWithProbes(
clientCtx context.Context,
clientCancel context.CancelFunc,
signalClient signal.Client,
mgmClient mgm.Client,
relayManager *relayClient.Manager,
config *EngineConfig,
mobileDep MobileDependency,
statusRecorder *peer.Status,
probes *ProbeHolder,
checks []*mgmProto.Checks,
) *Engine {
engine := &Engine{
clientCtx: clientCtx,
@@ -234,7 +213,7 @@ func NewEngineWithProbes(
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
relayManager: relayManager,
peerConns: make(map[string]*peer.Conn),
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
@@ -243,7 +222,6 @@ func NewEngineWithProbes(
networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
statusRecorder: statusRecorder,
probes: probes,
checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
}
@@ -287,6 +265,13 @@ func (e *Engine) Stop() error {
e.routeManager.Stop(e.stateManager)
}
if e.dnsForwardMgr != nil {
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
@@ -300,10 +285,6 @@ func (e *Engine) Stop() error {
return fmt.Errorf("failed to remove all peers: %s", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = nil
e.clientRoutesMu.Unlock()
if e.cancel != nil {
e.cancel()
}
@@ -373,16 +354,20 @@ func (e *Engine) Start() error {
}
e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(
e.ctx,
e.config.WgPrivateKey.PublicKey().String(),
e.config.DNSRouteInterval,
e.wgInterface,
e.statusRecorder,
e.relayManager,
initialRoutes,
e.stateManager,
)
e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: e.ctx,
PublicKey: e.config.WgPrivateKey.PublicKey().String(),
DNSRouteInterval: e.config.DNSRouteInterval,
WGInterface: e.wgInterface,
StatusRecorder: e.statusRecorder,
RelayManager: e.relayManager,
InitialRoutes: initialRoutes,
StateManager: e.stateManager,
DNSServer: dnsServer,
PeerStore: e.peerStore,
DisableClientRoutes: e.config.DisableClientRoutes,
DisableServerRoutes: e.config.DisableServerRoutes,
})
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
log.Errorf("Failed to initialize route manager: %s", err)
@@ -400,17 +385,8 @@ func (e *Engine) Start() error {
return fmt.Errorf("create wg interface: %w", err)
}
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil {
log.Errorf("failed creating firewall manager: %s", err)
}
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
if err := e.createFirewall(); err != nil {
return err
}
e.udpMux, err = e.wgInterface.Up()
@@ -444,7 +420,6 @@ func (e *Engine) Start() error {
e.receiveSignalEvents()
e.receiveManagementEvents()
e.receiveProbeEvents()
// starting network monitor at the very last to avoid disruptions
e.startNetworkMonitor()
@@ -452,6 +427,93 @@ func (e *Engine) Start() error {
return nil
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
return nil
}
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
}
if err := e.initFirewall(); err != nil {
return err
}
return nil
}
func (e *Engine) initFirewall() error {
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
}
if e.config.BlockLANAccess {
e.blockLanAccess()
}
if e.rpManager == nil || !e.config.RosenpassEnabled {
return nil
}
rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
manager.ProtocolUDP,
nil,
&port,
manager.ActionAccept,
"",
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
}
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
return nil
}
func (e *Engine) blockLanAccess() {
var merr *multierror.Error
// TODO: keep this updated
toBlock, err := getInterfacePrefixes()
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("get local addresses: %w", err))
}
log.Infof("blocking route LAN access for networks: %v", toBlock)
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
for _, network := range toBlock {
if _, err := e.firewall.AddRouteFiltering(
[]netip.Prefix{v4},
network,
manager.ProtocolALL,
nil,
nil,
manager.ActionDrop,
); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
}
}
if merr != nil {
log.Warnf("encountered errors blocking IPs to block LAN access: %v", nberrors.FormatErrorOrNil(merr))
}
}
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
@@ -460,8 +522,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
if allowedIPs != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p)
continue
}
@@ -492,17 +554,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
currentPeers := make([]string, 0, len(e.peerConns))
for p := range e.peerConns {
currentPeers = append(currentPeers, p)
}
newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey())
}
toRemove := util.SliceDiff(currentPeers, newPeers)
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
for _, p := range toRemove {
err := e.removePeer(p)
@@ -516,7 +573,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) removeAllPeers() error {
log.Debugf("removing all peer connections")
for p := range e.peerConns {
for _, p := range e.peerStore.PeersPubKey() {
err := e.removePeer(p)
if err != nil {
return err
@@ -540,9 +597,8 @@ func (e *Engine) removePeer(peerKey string) error {
}
}()
conn, exists := e.peerConns[peerKey]
conn, exists := e.peerStore.Remove(peerKey)
if exists {
delete(e.peerConns, peerKey)
conn.Close()
}
return nil
@@ -629,6 +685,15 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
@@ -649,18 +714,22 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} else {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
if netstack.IsEnabled() {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort))
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
return err
return fmt.Errorf("create ssh server: %w", err)
}
go func() {
// blocking
@@ -709,16 +778,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if conf.GetSshConfig() != nil {
err := e.updateSSH(conf.GetSshConfig())
if err != nil {
log.Warnf("failed handling SSH server setup %v", err)
log.Warnf("failed handling SSH server setup: %v", err)
}
}
e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{
IP: e.config.WgAddr,
PubKey: e.config.WgPrivateKey.PublicKey().String(),
KernelInterface: device.WireGuardModuleIsLoaded(),
FQDN: conf.GetFqdn(),
})
state := e.statusRecorder.GetLocalPeerState()
state.IP = e.config.WgAddr
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
state.KernelInterface = device.WireGuardModuleIsLoaded()
state.FQDN = conf.GetFqdn()
e.statusRecorder.UpdateLocalPeerState(state)
return nil
}
@@ -732,6 +802,15 @@ func (e *Engine) receiveManagementEvents() {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
}
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
)
// err = e.mgmClient.Sync(info, e.handleSync)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -786,7 +865,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
}
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
@@ -806,20 +884,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.acl.ApplyFiltering(networkMap)
}
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
// DNS forwarder
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
routes := toRoutes(networkMap.GetRoutes())
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -867,8 +941,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
protoDNSConfig = &mgmProto.DNSConfig{}
}
err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig))
if err != nil {
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
@@ -881,7 +954,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
if networkMap.PeerConfig != nil {
return networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled
}
return false
}
func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
routes := make([]*route.Route, 0)
for _, protoRoute := range protoRoutes {
var prefix netip.Prefix
@@ -892,6 +976,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
continue
}
}
convertedRoute := &route.Route{
ID: route.ID(protoRoute.ID),
Network: prefix,
@@ -908,6 +993,23 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
return routes
}
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
var dnsRoutes []string
for _, protoRoute := range protoRoutes {
if len(protoRoute.Domains) == 0 {
continue
}
if protoRoute.Peer == myPubKey {
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
}
}
return dnsRoutes
}
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config {
dnsUpdate := nbdns.Config{
ServiceEnable: protoDNSConfig.GetServiceEnable(),
@@ -982,12 +1084,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok {
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
e.peerConns[peerKey] = conn
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
@@ -1076,8 +1182,8 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn := e.peerConns[msg.Key]
if conn == nil {
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
@@ -1135,7 +1241,7 @@ func (e *Engine) receiveSignalEvents() {
return err
}
go conn.OnRemoteCandidate(candidate, e.GetClientRoutes())
go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes())
case sProto.Body_MODE:
}
@@ -1235,6 +1341,16 @@ func (e *Engine) close() {
func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) {
info := system.GetInfo(e.ctx)
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
)
netMap, err := e.mgmClient.GetNetworkMap(info)
if err != nil {
return nil, nil, err
@@ -1293,6 +1409,7 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
if e.dnsServer != nil {
return nil, e.dnsServer, nil
}
switch runtime.GOOS {
case "android":
routes, dnsConfig, err := e.readInitialSettings()
@@ -1306,14 +1423,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
*dnsConfig,
e.mobileDep.NetworkChangeListener,
e.statusRecorder,
e.config.DisableDNS,
)
go e.mobileDep.DnsReadyListener.OnReady()
return routes, dnsServer, nil
case "ios":
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS)
return nil, dnsServer, nil
default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS)
if err != nil {
return nil, nil, err
}
@@ -1322,26 +1442,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
}
}
// GetClientRoutes returns the current routes from the route map
func (e *Engine) GetClientRoutes() route.HAMap {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
return maps.Clone(e.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
e.clientRoutesMu.RLock()
defer e.clientRoutesMu.RUnlock()
routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes))
for id, v := range e.clientRoutes {
routes[id.NetID()] = v
}
return routes
}
// GetRouteManager returns the route manager
func (e *Engine) GetRouteManager() routemanager.Manager {
return e.routeManager
@@ -1382,73 +1482,58 @@ func (e *Engine) getRosenpassAddr() string {
return ""
}
func (e *Engine) receiveProbeEvents() {
if e.probes == nil {
return
// RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services
// and updates the status recorder with the latest states.
func (e *Engine) RunHealthProbes() bool {
signalHealthy := e.signal.IsHealthy()
log.Debugf("signal health check: healthy=%t", signalHealthy)
managementHealthy := e.mgmClient.IsHealthy()
log.Debugf("management health check: healthy=%t", managementHealthy)
results := append(e.probeSTUNs(), e.probeTURNs()...)
e.statusRecorder.UpdateRelayStates(results)
relayHealthy := true
for _, res := range results {
if res.Err != nil {
relayHealthy = false
break
}
}
if e.probes.SignalProbe != nil {
go e.probes.SignalProbe.Receive(e.ctx, func() bool {
healthy := e.signal.IsHealthy()
log.Debugf("received signal probe request, healthy: %t", healthy)
return healthy
})
log.Debugf("relay health check: healthy=%t", relayHealthy)
for _, key := range e.peerStore.PeersPubKey() {
wgStats, err := e.wgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
continue
}
// wgStats could be zero value, in which case we just reset the stats
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
if e.probes.MgmProbe != nil {
go e.probes.MgmProbe.Receive(e.ctx, func() bool {
healthy := e.mgmClient.IsHealthy()
log.Debugf("received management probe request, healthy: %t", healthy)
return healthy
})
}
if e.probes.RelayProbe != nil {
go e.probes.RelayProbe.Receive(e.ctx, func() bool {
healthy := true
results := append(e.probeSTUNs(), e.probeTURNs()...)
e.statusRecorder.UpdateRelayStates(results)
// A single failed server will result in a "failed" probe
for _, res := range results {
if res.Err != nil {
healthy = false
break
}
}
log.Debugf("received relay probe request, healthy: %t", healthy)
return healthy
})
}
if e.probes.WgProbe != nil {
go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request")
for _, peer := range e.peerConns {
key := peer.GetKey()
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
}
// wgStats could be zero value, in which case we just reset the stats
if err := e.statusRecorder.UpdateWireGuardPeerState(key, wgStats); err != nil {
log.Debugf("failed to update wg stats for peer %s: %s", key, err)
}
}
return true
})
}
allHealthy := signalHealthy && managementHealthy && relayHealthy
log.Debugf("all health checks completed: healthy=%t", allHealthy)
return allHealthy
}
func (e *Engine) probeSTUNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, e.STUNs)
e.syncMsgMux.Lock()
stuns := slices.Clone(e.STUNs)
e.syncMsgMux.Unlock()
return relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns)
}
func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs)
e.syncMsgMux.Lock()
turns := slices.Clone(e.TURNs)
e.syncMsgMux.Unlock()
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
}
func (e *Engine) restartEngine() {
@@ -1505,7 +1590,7 @@ func (e *Engine) startNetworkMonitor() {
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
for _, routes := range e.routeManager.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
@@ -1563,7 +1648,7 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nil, nil
}
// Create a deep copy to avoid external modifications
log.Debugf("Retrieving latest network map with size %d bytes", proto.Size(e.latestNetworkMap))
nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap)
if !ok {
@@ -1573,6 +1658,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
return nm, nil
}
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
if !enabled {
if e.dnsForwardMgr == nil {
return
}
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
return
}
if len(domains) > 0 {
log.Infof("enable domain router service for domains: %v", domains)
if e.dnsForwardMgr == nil {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
if err := e.dnsForwardMgr.Start(domains); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
e.dnsForwardMgr = nil
}
} else {
log.Infof("update domain router service for domains: %v", domains)
e.dnsForwardMgr.UpdateDomains(domains)
}
} else if e.dnsForwardMgr != nil {
log.Infof("disable domain router service")
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
}
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
@@ -1590,3 +1709,45 @@ func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
}
func getInterfacePrefixes() ([]netip.Prefix, error) {
ifaces, err := net.Interfaces()
if err != nil {
return nil, fmt.Errorf("get interfaces: %w", err)
}
var prefixes []netip.Prefix
var merr *multierror.Error
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("get addresses for interface %s: %w", iface.Name, err))
continue
}
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
merr = multierror.Append(merr, fmt.Errorf("cast address to IPNet: %v", addr))
continue
}
addr, ok := netip.AddrFromSlice(ipNet.IP)
if !ok {
merr = multierror.Append(merr, fmt.Errorf("cast IPNet to netip.Addr: %v", ipNet.IP))
continue
}
ones, _ := ipNet.Mask.Size()
prefix := netip.PrefixFrom(addr.Unmap(), ones).Masked()
ip := prefix.Addr()
// TODO: add IPv6
if !ip.Is4() || ip.IsLoopback() || ip.IsMulticast() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
continue
}
prefixes = append(prefixes, prefix)
}
}
return prefixes, nberrors.FormatErrorOrNil(merr)
}

View File

@@ -39,6 +39,8 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
relayClient "github.com/netbirdio/netbird/relay/client"
"github.com/netbirdio/netbird/route"
@@ -69,8 +71,7 @@ func TestMain(m *testing.M) {
}
func TestEngine_SSH(t *testing.T) {
// todo resolve test execution on freebsd
if runtime.GOOS == "windows" || runtime.GOOS == "freebsd" {
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH")
}
@@ -251,7 +252,14 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
},
}
engine.wgInterface = wgIface
engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil)
engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{
Context: ctx,
PublicKey: key.PublicKey().String(),
DNSRouteInterval: time.Minute,
WGInterface: engine.wgInterface,
StatusRecorder: engine.statusRecorder,
RelayManager: relayMgr,
})
_, _, err = engine.routeManager.Init()
require.NoError(t, err)
engine.dnsServer = &dns.MockServer{
@@ -391,8 +399,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return
}
if len(engine.peerConns) != c.expectedLen {
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns))
if len(engine.peerStore.PeersPubKey()) != c.expectedLen {
t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey()))
}
if engine.networkSerial != c.expectedSerial {
@@ -400,7 +408,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
for _, p := range c.expectedPeers {
conn, ok := engine.peerConns[p.GetWgPubKey()]
conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey())
if !ok {
t.Errorf("expecting Engine.peerConns to contain peer %s", p)
}
@@ -625,10 +633,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}{}
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
input.inputSerial = updateSerial
input.inputRoutes = newRoutes
return nil, nil, testCase.inputErr
return testCase.inputErr
},
}
@@ -801,8 +809,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
return nil, nil, nil
UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error {
return nil
},
}
@@ -1196,7 +1204,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
@@ -1218,7 +1226,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
}
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil)
if err != nil {
return nil, "", err
}
@@ -1237,7 +1245,8 @@ func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, conn := range e.peerConns {
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.Status() == peer.StatusConnected {
i++
}
@@ -1249,5 +1258,5 @@ func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerConns)
return len(e.peerStore.PeersPubKey())
}

View File

@@ -17,8 +17,9 @@ import (
)
// IsLoginRequired check that the server is support SSO or not
func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, sshKey string) (bool, error) {
mgmClient, err := getMgmClient(ctx, privateKey, mgmURL)
func IsLoginRequired(ctx context.Context, config *Config) (bool, error) {
mgmURL := config.ManagementURL
mgmClient, err := getMgmClient(ctx, config.PrivateKey, mgmURL)
if err != nil {
return false, err
}
@@ -33,12 +34,12 @@ func IsLoginRequired(ctx context.Context, privateKey string, mgmURL *url.URL, ss
}()
log.Debugf("connected to the Management service %s", mgmURL.String())
pubSSHKey, err := ssh.GeneratePublicKey([]byte(sshKey))
pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return false, err
}
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey)
_, err = doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if isLoginNeeded(err) {
return true, nil
}
@@ -67,10 +68,10 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
return err
}
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey)
serverKey, err := doMgmLogin(ctx, mgmClient, pubSSHKey, config)
if serverKey != nil && isRegistrationNeeded(err) {
log.Debugf("peer registration required")
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
_, err = registerPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey, config)
return err
}
@@ -99,7 +100,7 @@ func getMgmClient(ctx context.Context, privateKey string, mgmURL *url.URL) (*mgm
return mgmClient, err
}
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte) (*wgtypes.Key, error) {
func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte, config *Config) (*wgtypes.Key, error) {
serverKey, err := mgmClient.GetServerPublicKey()
if err != nil {
log.Errorf("failed while getting Management Service public key: %v", err)
@@ -107,13 +108,22 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
}
sysInfo := system.GetInfo(ctx)
sysInfo.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
)
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
return serverKey, err
}
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@@ -121,6 +131,15 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx)
info.SetFlags(
config.RosenpassEnabled,
config.RosenpassPermissive,
config.ServerSSHAllowed,
config.DisableClientRoutes,
config.DisableServerRoutes,
config.DisableDNS,
config.DisableFirewall,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
if err != nil {
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())

View File

@@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) {
conn.wgProxyRelay = proxy
}
// AllowedIP returns the allowed IP of the remote peer
func (conn *Conn) AllowedIP() net.IP {
return conn.allowedIP
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -17,6 +17,11 @@ import (
relayClient "github.com/netbirdio/netbird/relay/client"
)
type ResolvedDomainInfo struct {
Prefixes []netip.Prefix
ParentDomain domain.Domain
}
// State contains the latest state of a peer
type State struct {
Mux *sync.RWMutex
@@ -79,6 +84,12 @@ type LocalPeerState struct {
Routes map[string]struct{}
}
// Clone returns a copy of the LocalPeerState
func (l LocalPeerState) Clone() LocalPeerState {
l.Routes = maps.Clone(l.Routes)
return l
}
// SignalState contains the latest state of a signal connection
type SignalState struct {
URL string
@@ -138,7 +149,7 @@ type Status struct {
rosenpassEnabled bool
rosenpassPermissive bool
nsGroupStates []NSGroupState
resolvedDomainsStates map[domain.Domain][]netip.Prefix
resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo
// To reduce the number of notification invocation this bool will be true when need to call the notification
// Some Peer actions mostly used by in a batch when the network map has been synchronized. In these type of events
@@ -156,7 +167,7 @@ func NewRecorder(mgmAddress string) *Status {
offlinePeers: make([]State, 0),
notifier: newNotifier(),
mgmAddress: mgmAddress,
resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix),
resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{},
}
}
@@ -496,7 +507,7 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} {
func (d *Status) GetLocalPeerState() LocalPeerState {
d.mux.Lock()
defer d.mux.Unlock()
return d.localPeer
return d.localPeer.Clone()
}
// UpdateLocalPeerState updates local peer status
@@ -591,16 +602,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates
}
func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) {
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) {
d.mux.Lock()
defer d.mux.Unlock()
d.resolvedDomainsStates[domain] = prefixes
// Store both the original domain pattern and resolved domain
d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{
Prefixes: prefixes,
ParentDomain: originalDomain,
}
}
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
d.mux.Lock()
defer d.mux.Unlock()
delete(d.resolvedDomainsStates, domain)
// Remove all entries that have this domain as their parent
for k, v := range d.resolvedDomainsStates {
if v.ParentDomain == domain {
delete(d.resolvedDomainsStates, k)
}
}
}
func (d *Status) GetRosenpassState() RosenpassState {
@@ -702,7 +724,7 @@ func (d *Status) GetDNSStates() []NSGroupState {
return d.nsGroupStates
}
func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix {
func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo {
d.mux.Lock()
defer d.mux.Unlock()
return maps.Clone(d.resolvedDomainsStates)

View File

@@ -255,6 +255,10 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) {
defer w.muxAgent.Unlock()
cancel()
if w.agent == nil {
return
}
if err := w.agent.Close(); err != nil {
w.log.Warnf("failed to close ICE agent: %s", err)
}

View File

@@ -0,0 +1,87 @@
package peerstore
import (
"net"
"sync"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/peer"
)
// Store is a thread-safe store for peer connections.
type Store struct {
peerConns map[string]*peer.Conn
peerConnsMu sync.RWMutex
}
func NewConnStore() *Store {
return &Store{
peerConns: make(map[string]*peer.Conn),
}
}
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
_, ok := s.peerConns[pubKey]
if ok {
return false
}
s.peerConns[pubKey] = conn
return true
}
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
delete(s.peerConns, pubKey)
return p, true
}
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return "", false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p.AllowedIP(), true
}
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p, true
}
func (s *Store) PeersPubKey() []string {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
return maps.Keys(s.peerConns)
}

View File

@@ -1,58 +0,0 @@
package internal
import "context"
type ProbeHolder struct {
MgmProbe *Probe
SignalProbe *Probe
RelayProbe *Probe
WgProbe *Probe
}
// Probe allows to run on-demand callbacks from different code locations.
// Pass the probe to a receiving and a sending end. The receiving end starts listening
// to requests with Receive and executes a callback when the sending end requests it
// by calling Probe.
type Probe struct {
request chan struct{}
result chan bool
ready bool
}
// NewProbe returns a new initialized probe.
func NewProbe() *Probe {
return &Probe{
request: make(chan struct{}),
result: make(chan bool),
}
}
// Probe requests the callback to be run and returns a bool indicating success.
// It always returns true as long as the receiver is not ready.
func (p *Probe) Probe() bool {
if !p.ready {
return true
}
p.request <- struct{}{}
return <-p.result
}
// Receive starts listening for probe requests. On such a request it runs the supplied
// callback func which must return a bool indicating success.
// Blocks until the passed context is cancelled.
func (p *Probe) Receive(ctx context.Context, callback func() bool) {
p.ready = true
defer func() {
p.ready = false
}()
for {
select {
case <-ctx.Done():
return
case <-p.request:
p.result <- callback()
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"reflect"
runtime "runtime"
"time"
"github.com/hashicorp/go-multierror"
@@ -13,12 +14,20 @@ import (
"github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/static"
"github.com/netbirdio/netbird/route"
)
const (
handlerTypeDynamic = iota
handlerTypeDomain
handlerTypeStatic
)
type routerPeerStatus struct {
connected bool
relayed bool
@@ -53,7 +62,18 @@ type clientNetwork struct {
updateSerial uint64
}
func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork {
func newClientNetworkWatcher(
ctx context.Context,
dnsRouteInterval time.Duration,
wgInterface iface.IWGIface,
statusRecorder *peer.Status,
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
client := &clientNetwork{
@@ -65,7 +85,17 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration
routePeersNotifiers: make(map[string]chan struct{}),
routeUpdate: make(chan routesUpdate),
peerStateUpdate: make(chan struct{}),
handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface),
handler: handlerFromRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouteInterval,
statusRecorder,
wgInterface,
dnsServer,
peerStore,
useNewDNSRoute,
),
}
return client
}
@@ -368,10 +398,50 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
}
}
func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler {
if rt.IsDynamic() {
func handlerFromRoute(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsRouterInteval time.Duration,
statusRecorder *peer.Status,
wgInterface iface.IWGIface,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
useNewDNSRoute bool,
) RouteHandler {
switch handlerType(rt, useNewDNSRoute) {
case handlerTypeDomain:
return dnsinterceptor.New(
rt,
routeRefCounter,
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerStore,
)
case handlerTypeDynamic:
dns := nbdns.NewServiceViaMemory(wgInterface)
return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()))
return dynamic.NewRoute(
rt,
routeRefCounter,
allowedIPsRefCounter,
dnsRouterInteval,
statusRecorder,
wgInterface,
fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()),
)
default:
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter)
}
func handlerType(rt *route.Route, useNewDNSRoute bool) int {
if !rt.IsDynamic() {
return handlerTypeStatic
}
if useNewDNSRoute && runtime.GOOS != "ios" {
return handlerTypeDomain
}
return handlerTypeDynamic
}

View File

@@ -0,0 +1,356 @@
package dnsinterceptor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
)
type domainMap map[domain.Domain][]netip.Prefix
type DnsInterceptor struct {
mu sync.RWMutex
route *route.Route
routeRefCounter *refcounter.RouteRefCounter
allowedIPsRefcounter *refcounter.AllowedIPsRefCounter
statusRecorder *peer.Status
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
peerStore *peerstore.Store
}
func New(
rt *route.Route,
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
peerStore *peerstore.Store,
) *DnsInterceptor {
return &DnsInterceptor{
route: rt,
routeRefCounter: routeRefCounter,
allowedIPsRefcounter: allowedIPsRefCounter,
statusRecorder: statusRecorder,
dnsServer: dnsServer,
interceptedDomains: make(domainMap),
peerStore: peerStore,
}
}
func (d *DnsInterceptor) String() string {
return d.route.Domains.SafeString()
}
func (d *DnsInterceptor) AddRoute(context.Context) error {
d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute)
return nil
}
func (d *DnsInterceptor) RemoveRoute() error {
d.mu.Lock()
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", "))
}
for _, domain := range d.route.Domains {
d.statusRecorder.DeleteResolvedDomainsStates(domain)
}
clear(d.interceptedDomains)
d.mu.Unlock()
d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute)
return nberrors.FormatErrorOrNil(merr)
}
func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error {
d.mu.Lock()
defer d.mu.Unlock()
var merr *multierror.Error
for domain, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != peerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
domain.SafeString(),
ref.Out,
)
}
}
}
d.currentPeerKey = peerKey
return nberrors.FormatErrorOrNil(merr)
}
func (d *DnsInterceptor) RemoveAllowedIPs() error {
d.mu.Lock()
defer d.mu.Unlock()
var merr *multierror.Error
for _, prefixes := range d.interceptedDomains {
for _, prefix := range prefixes {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
d.currentPeerKey = ""
return nberrors.FormatErrorOrNil(merr)
}
// ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
return
}
log.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
if peerKey == "" {
log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name)
d.continueToNextHandler(w, r, "no current peer key")
return
}
upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil {
log.Errorf("failed to get upstream IP: %v", err)
d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err))
return
}
client := &dns.Client{
Timeout: 5 * time.Second,
Net: "udp",
}
upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort)
reply, _, err := client.ExchangeContext(context.Background(), r, upstream)
var answer []dns.RR
if reply != nil {
answer = reply.Answer
}
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer)
if err != nil {
log.Errorf("failed to exchange DNS request with %s: %v", upstream, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
return
}
reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err)
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", peerKey)
}
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {
if r == nil {
return fmt.Errorf("received nil DNS message")
}
if len(r.Answer) > 0 && len(r.Question) > 0 {
origPattern := ""
if writer, ok := w.(*nbdns.ResponseWriterChain); ok {
origPattern = writer.GetOrigPattern()
}
resolvedDomain := domain.Domain(r.Question[0].Name)
// already punycode via RegisterHandler()
originalDomain := domain.Domain(origPattern)
if originalDomain == "" {
originalDomain = resolvedDomain
}
var newPrefixes []netip.Prefix
for _, answer := range r.Answer {
var ip netip.Addr
switch rr := answer.(type) {
case *dns.A:
addr, ok := netip.AddrFromSlice(rr.A)
if !ok {
log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A)
continue
}
ip = addr
case *dns.AAAA:
addr, ok := netip.AddrFromSlice(rr.AAAA)
if !ok {
log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA)
continue
}
ip = addr
default:
continue
}
prefix := netip.PrefixFrom(ip, ip.BitLen())
newPrefixes = append(newPrefixes, prefix)
}
if len(newPrefixes) > 0 {
if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil {
log.Errorf("failed to update domain prefixes: %v", err)
}
}
}
if err := w.WriteMsg(r); err != nil {
return fmt.Errorf("failed to write DNS response: %v", err)
}
return nil
}
func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error {
d.mu.Lock()
defer d.mu.Unlock()
oldPrefixes := d.interceptedDomains[resolvedDomain]
toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes)
var merr *multierror.Error
// Add new prefixes
for _, prefix := range toAdd {
if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err))
continue
}
if d.currentPeerKey == "" {
continue
}
if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err))
} else if ref.Count > 1 && ref.Out != d.currentPeerKey {
log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. HA routing disabled",
prefix.Addr(),
resolvedDomain.SafeString(),
ref.Out,
)
}
}
if !d.route.KeepRoute {
// Remove old prefixes
for _, prefix := range toRemove {
if _, err := d.routeRefCounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err))
}
if d.currentPeerKey != "" {
if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err))
}
}
}
}
// Update domain prefixes using resolved domain as key
if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes)
if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toAdd)
}
if len(toRemove) > 0 {
log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s",
resolvedDomain.SafeString(),
originalDomain.SafeString(),
toRemove)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) {
prefixSet := make(map[netip.Prefix]bool)
for _, prefix := range oldPrefixes {
prefixSet[prefix] = false
}
for _, prefix := range newPrefixes {
if _, exists := prefixSet[prefix]; exists {
prefixSet[prefix] = true
} else {
toAdd = append(toAdd, prefix)
}
}
for prefix, inUse := range prefixSet {
if !inUse {
toRemove = append(toRemove, prefix)
}
}
return
}

View File

@@ -74,11 +74,7 @@ func NewRoute(
}
func (r *Route) String() string {
s, err := r.route.Domains.String()
if err != nil {
return r.route.Domains.PunycodeString()
}
return s
return r.route.Domains.SafeString()
}
func (r *Route) AddRoute(ctx context.Context) error {
@@ -292,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes)
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes)
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -12,12 +12,16 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@@ -33,15 +37,32 @@ import (
// Manager is a route manager interface
type Manager interface {
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error
TriggerSelection(route.HAMap)
GetRouteSelector() *routeselector.RouteSelector
GetClientRoutes() route.HAMap
GetClientRoutesWithNetID() map[route.NetID][]*route.Route
SetRouteChangeListener(listener listener.NetworkChangeListener)
InitialRouteRange() []string
EnableServerRouter(firewall firewall.Manager) error
Stop(stateManager *statemanager.Manager)
}
type ManagerConfig struct {
Context context.Context
PublicKey string
DNSRouteInterval time.Duration
WGInterface iface.IWGIface
StatusRecorder *peer.Status
RelayManager *relayClient.Manager
InitialRoutes []*route.Route
StateManager *statemanager.Manager
DNSServer dns.Server
PeerStore *peerstore.Store
DisableClientRoutes bool
DisableServerRoutes bool
}
// DefaultManager is the default instance of a route manager
type DefaultManager struct {
ctx context.Context
@@ -49,7 +70,7 @@ type DefaultManager struct {
mux sync.Mutex
clientNetworks map[route.HAUniqueID]*clientNetwork
routeSelector *routeselector.RouteSelector
serverRouter serverRouter
serverRouter *serverRouter
sysOps *systemops.SysOps
statusRecorder *peer.Status
relayMgr *relayClient.Manager
@@ -60,52 +81,80 @@ type DefaultManager struct {
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter
dnsRouteInterval time.Duration
stateManager *statemanager.Manager
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server
peerStore *peerstore.Store
useNewDNSRoute bool
disableClientRoutes bool
disableServerRoutes bool
}
func NewManager(
ctx context.Context,
pubKey string,
dnsRouteInterval time.Duration,
wgInterface iface.IWGIface,
statusRecorder *peer.Status,
relayMgr *relayClient.Manager,
initialRoutes []*route.Route,
stateManager *statemanager.Manager,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
func NewManager(config ManagerConfig) *DefaultManager {
mCTX, cancel := context.WithCancel(config.Context)
notifier := notifier.NewNotifier()
sysOps := systemops.NewSysOps(wgInterface, notifier)
sysOps := systemops.NewSysOps(config.WGInterface, notifier)
dm := &DefaultManager{
ctx: mCTX,
stop: cancel,
dnsRouteInterval: dnsRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: relayMgr,
sysOps: sysOps,
statusRecorder: statusRecorder,
wgInterface: wgInterface,
pubKey: pubKey,
notifier: notifier,
stateManager: stateManager,
ctx: mCTX,
stop: cancel,
dnsRouteInterval: config.DNSRouteInterval,
clientNetworks: make(map[route.HAUniqueID]*clientNetwork),
relayMgr: config.RelayManager,
sysOps: sysOps,
statusRecorder: config.StatusRecorder,
wgInterface: config.WGInterface,
pubKey: config.PublicKey,
notifier: notifier,
stateManager: config.StateManager,
dnsServer: config.DNSServer,
peerStore: config.PeerStore,
disableClientRoutes: config.DisableClientRoutes,
disableServerRoutes: config.DisableServerRoutes,
}
dm.routeRefCounter = refcounter.New(
// don't proceed with client routes if it is disabled
if config.DisableClientRoutes {
return dm
}
dm.setupRefCounters()
if runtime.GOOS == "android" {
cr := dm.initialClientRoutes(config.InitialRoutes)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
func (m *DefaultManager) setupRefCounters() {
m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface())
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
},
func(prefix netip.Prefix, _ struct{}) error {
return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface())
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
},
)
dm.allowedIPsRefCounter = refcounter.New(
if netstack.IsEnabled() {
m.routeRefCounter = refcounter.New(
func(netip.Prefix, struct{}) (struct{}, error) {
return struct{}{}, refcounter.ErrIgnore
},
func(netip.Prefix, struct{}) error {
return nil
},
)
}
m.allowedIPsRefCounter = refcounter.New(
func(prefix netip.Prefix, peerKey string) (string, error) {
// save peerKey to use it in the remove function
return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String())
return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String())
},
func(prefix netip.Prefix, peerKey string) error {
if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil {
if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) {
return err
}
@@ -114,19 +163,13 @@ func NewManager(
return nil
},
)
if runtime.GOOS == "android" {
cr := dm.clientRoutes(initialRoutes)
dm.notifier.SetInitialClientRoutes(cr)
}
return dm
}
// Init sets up the routing
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
m.routeSelector = m.initSelector()
if nbnet.CustomRoutingDisabled() {
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
return nil, nil, nil
}
@@ -172,6 +215,15 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
}
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
if m.disableServerRoutes {
log.Info("server routes are disabled")
return nil
}
if firewall == nil {
return errors.New("firewall manager is not set")
}
var err error
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
if err != nil {
@@ -198,7 +250,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
}
if !nbnet.CustomRoutingDisabled() {
if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes {
if err := m.sysOps.CleanupRouting(stateManager); err != nil {
log.Errorf("Error cleaning up routing: %v", err)
} else {
@@ -207,33 +259,43 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
m.ctx = nil
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
}
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error {
select {
case <-m.ctx.Done():
log.Infof("not updating routes as context is closed")
return nil, nil, m.ctx.Err()
return nil
default:
m.mux.Lock()
defer m.mux.Unlock()
}
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
m.mux.Lock()
defer m.mux.Unlock()
m.useNewDNSRoute = useNewDNSRoute
newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes)
if !m.disableClientRoutes {
filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap)
m.updateClientNetworks(updateSerial, filteredClientRoutes)
m.notifier.OnNewRoutes(filteredClientRoutes)
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return nil, nil, fmt.Errorf("update routes: %w", err)
}
}
return newServerRoutesMap, newClientRoutesIDMap, nil
}
if m.serverRouter != nil {
err := m.serverRouter.updateRoutes(newServerRoutesMap)
if err != nil {
return err
}
}
m.clientRoutes = newClientRoutesIDMap
return nil
}
// SetRouteChangeListener set RouteListener for route change Notifier
@@ -251,9 +313,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector {
return m.routeSelector
}
// GetClientRoutes returns the client routes
func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork {
return m.clientNetworks
// GetClientRoutes returns most recent list of clientRoutes received from the Management Service
func (m *DefaultManager) GetClientRoutes() route.HAMap {
m.mux.Lock()
defer m.mux.Unlock()
return maps.Clone(m.clientRoutes)
}
// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only
func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
m.mux.Lock()
defer m.mux.Unlock()
routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes))
for id, v := range m.clientRoutes {
routes[id.NetID()] = v
}
return routes
}
// TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones
@@ -273,7 +350,18 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
continue
}
clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher := newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes})
@@ -302,7 +390,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter)
clientNetworkWatcher = newClientNetworkWatcher(
m.ctx,
m.dnsRouteInterval,
m.wgInterface,
m.statusRecorder,
routes[0],
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerStore,
m.useNewDNSRoute,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}
@@ -345,7 +444,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID]
return newServerRoutesMap, newClientRoutesIDMap
}
func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route {
func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route {
_, crMap := m.classifyRoutes(initialRoutes)
rs := make([]*route.Route, 0, len(crMap))
for _, routes := range crMap {

View File

@@ -424,7 +424,12 @@ func TestManagerUpdateRoutes(t *testing.T) {
statusRecorder := peer.NewRecorder("https://mgm")
ctx := context.TODO()
routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil)
routeManager := NewManager(ManagerConfig{
Context: ctx,
PublicKey: localPeerKey,
WGInterface: wgInterface,
StatusRecorder: statusRecorder,
})
_, _, err = routeManager.Init()
@@ -436,11 +441,11 @@ func TestManagerUpdateRoutes(t *testing.T) {
}
if len(testCase.inputInitRoutes) > 0 {
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes)
_ = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false)
require.NoError(t, err, "should update routes with init routes")
}
_, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
_ = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false)
require.NoError(t, err, "should update routes")
expectedWatchers := testCase.clientNetworkWatchersExpected
@@ -450,8 +455,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
sr := routeManager.serverRouter.(*defaultServerRouter)
require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match")
require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match")
}
})
}

View File

@@ -2,7 +2,6 @@ package routemanager
import (
"context"
"fmt"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
@@ -15,10 +14,12 @@ import (
// MockManager is the mock instance of a route manager
type MockManager struct {
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error)
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
StopFunc func(manager *statemanager.Manager)
UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error
TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap
GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route
StopFunc func(manager *statemanager.Manager)
}
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
@@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string {
}
// UpdateRoutes mock implementation of UpdateRoutes from Manager interface
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) {
func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error {
if m.UpdateRoutesFunc != nil {
return m.UpdateRoutesFunc(updateSerial, newRoutes)
}
return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented")
return nil
}
func (m *MockManager) TriggerSelection(networks route.HAMap) {
@@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector {
return nil
}
// GetClientRoutes mock implementation of GetClientRoutes from Manager interface
func (m *MockManager) GetClientRoutes() route.HAMap {
if m.GetClientRoutesFunc != nil {
return m.GetClientRoutesFunc()
}
return nil
}
// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface
func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route {
if m.GetClientRoutesWithNetIDFunc != nil {
return m.GetClientRoutesWithNetIDFunc()
}
return nil
}
// Start mock implementation of Start from Manager interface
func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) {
}

View File

@@ -1,9 +0,0 @@
package routemanager
import "github.com/netbirdio/netbird/route"
type serverRouter interface {
updateRoutes(map[route.ID]*route.Route) error
removeFromServerNetwork(*route.Route) error
cleanUp()
}

View File

@@ -9,8 +9,19 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route"
)
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) {
type serverRouter struct {
}
func (r serverRouter) cleanUp() {
}
func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error {
return nil
}
func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) {
return nil, fmt.Errorf("server route not supported on this os")
}

View File

@@ -17,7 +17,7 @@ import (
"github.com/netbirdio/netbird/route"
)
type defaultServerRouter struct {
type serverRouter struct {
mux sync.Mutex
ctx context.Context
routes map[route.ID]*route.Route
@@ -26,8 +26,8 @@ type defaultServerRouter struct {
statusRecorder *peer.Status
}
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) {
return &defaultServerRouter{
func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) {
return &serverRouter{
ctx: ctx,
routes: make(map[route.ID]*route.Route),
firewall: firewall,
@@ -36,7 +36,7 @@ func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall f
}, nil
}
func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error {
serverRoutesToRemove := make([]route.ID, 0)
for routeID := range m.routes {
@@ -80,74 +80,72 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route)
return nil
}
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
if m.ctx.Err() != nil {
log.Infof("Not removing from server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.RemoveNatRule(routerPair)
if err != nil {
return fmt.Errorf("remove routing rules: %w", err)
}
delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState()
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
select {
case <-m.ctx.Done():
func (m *serverRouter) addToServerNetwork(route *route.Route) error {
if m.ctx.Err() != nil {
log.Infof("Not adding to server network because context is done")
return m.ctx.Err()
default:
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.AddNatRule(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
m.mux.Lock()
defer m.mux.Unlock()
routerPair, err := routeToRouterPair(route)
if err != nil {
return fmt.Errorf("parse prefix: %w", err)
}
err = m.firewall.AddNatRule(routerPair)
if err != nil {
return fmt.Errorf("insert routing rules: %w", err)
}
m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String()
if route.IsDynamic() {
routeStr = route.Domains.SafeString()
}
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state)
return nil
}
func (m *defaultServerRouter) cleanUp() {
func (m *serverRouter) cleanUp() {
m.mux.Lock()
defer m.mux.Unlock()
for _, r := range m.routes {

View File

@@ -17,6 +17,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
@@ -62,6 +63,17 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
r.removeFromRouteTable,
)
if netstack.IsEnabled() {
refCounter = refcounter.New(
func(netip.Prefix, struct{}) (Nexthop, error) {
return Nexthop{}, refcounter.ErrIgnore
},
func(netip.Prefix, Nexthop) error {
return nil
},
)
}
r.refCounter = refCounter
return r.setupHooks(initAddresses, stateManager)

View File

@@ -266,7 +266,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
return fmt.Errorf("netlink add route: %w", err)
}
@@ -289,7 +289,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
Dst: ipNet,
}
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
return fmt.Errorf("netlink add unreachable route: %w", err)
}
@@ -312,7 +312,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
if err := netlink.RouteDel(route); err != nil &&
!errors.Is(err, syscall.ESRCH) &&
!errors.Is(err, syscall.ENOENT) &&
!errors.Is(err, syscall.EAFNOSUPPORT) {
!isOpErr(err) {
return fmt.Errorf("netlink remove unreachable route: %w", err)
}
@@ -338,7 +338,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error {
return fmt.Errorf("add gateway and device: %w", err)
}
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !isOpErr(err) {
return fmt.Errorf("netlink remove route: %w", err)
}
@@ -362,7 +362,7 @@ func flushRoutes(tableID, family int) error {
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
}
}
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RouteDel(&routes[i]); err != nil && !isOpErr(err) {
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
}
}
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !isOpErr(err) {
return fmt.Errorf("remove routing rule: %w", err)
}
@@ -509,3 +509,13 @@ func hasSeparateRouting() ([]netip.Prefix, error) {
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
log.Debugf("route operation not supported: %v", err)
return true
}
return false
}

View File

@@ -303,20 +303,29 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage,
var rawStates map[string]json.RawMessage
if err := json.Unmarshal(data, &rawStates); err != nil {
if deleteCorrupt {
log.Warn("State file appears to be corrupted, attempting to delete it", err)
if err := os.Remove(m.filePath); err != nil {
log.Errorf("Failed to delete corrupted state file: %v", err)
} else {
log.Info("State file deleted")
}
}
m.handleCorruptedState(deleteCorrupt)
return nil, fmt.Errorf("unmarshal states: %w", err)
}
return rawStates, nil
}
// handleCorruptedState creates a backup of a corrupted state file by moving it
func (m *Manager) handleCorruptedState(deleteCorrupt bool) {
if !deleteCorrupt {
return
}
log.Warn("State file appears to be corrupted, attempting to back it up")
backupPath := fmt.Sprintf("%s.corrupted.%d", m.filePath, time.Now().UnixNano())
if err := os.Rename(m.filePath, backupPath); err != nil {
log.Errorf("Failed to backup corrupted state file: %v", err)
return
}
log.Infof("Created backup of corrupted state file at: %s", backupPath)
}
// loadSingleRawState unmarshals a raw state into a concrete state object
func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) {
stateType, ok := m.stateTypes[name]

View File

@@ -1,23 +1,16 @@
package statemanager
import (
"github.com/netbirdio/netbird/client/configs"
"os"
"path/filepath"
"runtime"
)
// GetDefaultStatePath returns the path to the state file based on the operating system
// It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string {
switch runtime.GOOS {
case "windows":
return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
return "/var/db/netbird/state.json"
if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" {
return path
}
return ""
return filepath.Join(configs.StateDir, "state.json")
}

View File

@@ -207,7 +207,7 @@ func (c *Client) IsLoginRequired() bool {
ConfigPath: c.cfgFile,
})
needsLogin, _ := internal.IsLoginRequired(ctx, cfg.PrivateKey, cfg.ManagementURL, cfg.SSHKey)
needsLogin, _ := internal.IsLoginRequired(ctx, cfg)
return needsLogin
}
@@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
return nil, fmt.Errorf("not connected")
}
routesMap := engine.GetClientRoutesWithNetID()
routeManager := engine.GetRouteManager()
routesMap := routeManager.GetClientRoutesWithNetID()
if routeManager == nil {
return nil, fmt.Errorf("could not get route manager")
}
@@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
}
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails {
func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails {
var routeSelection []RoutesSelectionInfo
for _, r := range routes {
domainList := make([]DomainInfo, 0)
@@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
domainResp := DomainInfo{
Domain: d.SafeString(),
}
if prefixes, exists := resolvedDomains[d]; exists {
if info, exists := resolvedDomains[d]; exists {
var ipStrings []string
for _, prefix := range prefixes {
for _, prefix := range info.Prefixes {
ipStrings = append(ipStrings, prefix.Addr().String())
}
domainResp.ResolvedIPs = strings.Join(ipStrings, ", ")
@@ -365,12 +366,12 @@ func (c *Client) SelectRoute(id string) error {
} else {
log.Debugf("select route with id: %s", id)
routes := toNetIDs([]string{id})
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when selecting routes: %s", err)
return fmt.Errorf("select routes: %w", err)
}
}
routeManager.TriggerSelection(engine.GetClientRoutes())
routeManager.TriggerSelection(routeManager.GetClientRoutes())
return nil
}
@@ -392,12 +393,12 @@ func (c *Client) DeselectRoute(id string) error {
} else {
log.Debugf("deselect route with id: %s", id)
routes := toNetIDs([]string{id})
if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil {
if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil {
log.Debugf("error when deselecting routes: %s", err)
return fmt.Errorf("deselect routes: %w", err)
}
}
routeManager.TriggerSelection(engine.GetClientRoutes())
routeManager.TriggerSelection(routeManager.GetClientRoutes())
return nil
}

Some files were not shown because too many files have changed in this diff Show More