Compare commits

...

69 Commits

Author SHA1 Message Date
Viktor Liu
3efa7a282a Set log level debug 2024-11-27 13:46:37 +01:00
Viktor Liu
40551099b3 Add debug 2024-11-27 13:41:32 +01:00
Pascal Fischer
9db1932664 [management] Fix getSetupKey call (#2927) 2024-11-22 10:15:51 +01:00
Viktor Liu
1bbabf70b0 [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict

* Fix chain name
2024-11-21 16:53:37 +01:00
Pascal Fischer
aa575d6f44 [management] Add activity events to group propagation flow (#2916) 2024-11-21 15:10:34 +01:00
Pascal Fischer
f66bbcc54c [management] Add metric for peer meta update (#2913) 2024-11-19 18:13:26 +01:00
Pascal Fischer
5dd6a08ea6 link peer meta update back to account object (#2911) 2024-11-19 17:25:49 +01:00
Krzysztof Nazarewski (kdn)
eb5d0569ae [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)
* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
2024-11-19 14:14:58 +01:00
Pascal Fischer
52ea2e84e9 [management] Add transaction metrics and exclude getAccount time from peers update (#2904) 2024-11-19 00:04:50 +01:00
Maycon Santos
78fab877c0 [misc] Update signing pipeline version (#2900) 2024-11-18 15:31:53 +01:00
Maycon Santos
65a94f695f use google domain for tests (#2902) 2024-11-18 12:55:02 +01:00
Kursat Aktas
ec543f89fb Introducing NetBird Guru on Gurubase.io (#2778) 2024-11-16 15:45:31 +01:00
Viktor Liu
a7d5c52203 Fix error state race on mgmt connection error (#2892) 2024-11-15 22:59:49 +01:00
Viktor Liu
582bb58714 Move state updates outside the refcounter (#2897) 2024-11-15 22:55:33 +01:00
Viktor Liu
121dfda915 [client] Fix state manager race conditions (#2890) 2024-11-15 20:05:26 +01:00
İsmail
a1c5287b7c Fix the Inactivity Expiration problem. (#2865) 2024-11-15 18:21:27 +01:00
Bethuel Mmbaga
12f442439a [management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-15 20:09:32 +03:00
Pascal Fischer
d9b691b8a5 [management] Limit the setup-key update operation (#2841) 2024-11-15 17:00:06 +01:00
Pascal Fischer
4aee3c9e33 [client/management] add peer lock to peer meta update and fix isEqual func (#2840) 2024-11-15 16:59:03 +01:00
Pascal Fischer
44e799c687 [management] Fix limited peer view groups (#2894) 2024-11-15 11:16:16 +01:00
Viktor Liu
be78efbd42 [client] Handle panic on nil wg interface (#2891) 2024-11-14 20:15:16 +01:00
Maycon Santos
6886691213 Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
2024-11-13 15:21:33 +01:00
Zoltan Papp
b48afd92fd [relay-server] Always close ws conn when work thread exit (#2879)
Close ws conn when work thread exit
2024-11-13 15:02:51 +01:00
Viktor Liu
39329e12a1 [client] Improve state write timeout and abort work early on timeout (#2882)
* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
2024-11-13 13:46:00 +01:00
Pascal Fischer
20a5afc359 [management] Add more logs to the peer update processes (#2881) 2024-11-12 14:19:22 +01:00
Bethuel Mmbaga
6cb697eed6 [management] Refactor setup key to use store methods (#2861)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context from DB queries

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 19:46:10 +03:00
Viktor Liu
e0bed2b0fb [client] Fix race conditions (#2869)
* Fix concurrent map access in status

* Fix race when retrieving ctx state error

* Fix race when accessing service controller server instance
2024-11-11 14:55:10 +01:00
Zoltan Papp
30f025e7dd [client] fix/proxy close (#2873)
When the remote peer switches the Relay instance then must to close the proxy connection to the old instance.

It can cause issues when the remote peer switch connects to the Relay instance multiple times and then reconnects to an instance it had previously connected to.
2024-11-11 14:18:38 +01:00
Zoltan Papp
b4d7605147 [client] Remove loop after route calculation (#2856)
- ICE do not trigger disconnect callbacks if the stated did not change
- Fix route calculation callback loop
- Move route state updates into protected scope by mutex
- Do not calculate routes in case of peer.Open() and peer.Close()
2024-11-11 10:53:57 +01:00
Viktor Liu
08b6e9d647 [management] Fix api error message typo peers_group (#2862) 2024-11-08 23:28:02 +01:00
Pascal Fischer
67ce14eaea [management] Add peer lock to grpc server (#2859)
* add peer lock to grpc server

* remove sleep and put db update first

* don't export lock method
2024-11-08 18:47:22 +01:00
Pascal Fischer
669904cd06 [management] Remove context from database calls (#2863) 2024-11-08 15:49:00 +01:00
Zoltan Papp
4be826450b [client] Use offload in WireGuard bind receiver (#2815)
Improve the performance on Linux and Android in case of P2P connections
2024-11-07 17:28:38 +01:00
Maycon Santos
738387f2de Add benchmark tests to get account with claims (#2761)
* Add benchmark tests to get account with claims

* add users to account objects

* remove hardcoded env
2024-11-07 17:23:35 +01:00
Pascal Fischer
baf0678ceb [management] Fix potential panic on inactivity expiration log message (#2854) 2024-11-07 16:33:57 +01:00
Pascal Fischer
7fef8f6758 [management] Enforce max conn of 1 for sqlite setups (#2855) 2024-11-07 16:32:35 +01:00
Viktor Liu
6829a64a2d [client] Exclude split default route ip addresses from anonymization (#2853) 2024-11-07 16:29:32 +01:00
Zoltan Papp
cbf500024f [relay-server] Use X-Real-IP in case of reverse proxy (#2848)
* Use X-Real-IP in case of reverse proxy

* Use sprintf
2024-11-07 16:14:53 +01:00
Viktor Liu
509e184e10 [client] Use the prerouting chain to mark for masquerading to support older systems (#2808) 2024-11-07 12:37:04 +01:00
Pascal Fischer
3e88b7c56e [management] Fix network map update on peer validation (#2849) 2024-11-07 09:50:13 +01:00
Maycon Santos
b952d8693d Fix cached device flow oauth (#2833)
This change removes the cached device flow oauth info when a down command is called

Removing the need for the agent to be restarted
2024-11-05 14:51:17 +01:00
Maycon Santos
5b46cc8e9c Avoid failing all other matrix tests if one fails (#2839) 2024-11-05 13:28:42 +01:00
Pascal Fischer
a9d06b883f add all group to add peer affected peers network map check (#2830) 2024-11-01 22:09:08 +01:00
Viktor Liu
5f06b202c3 [client] Log windows panics (#2829) 2024-11-01 15:08:22 +01:00
Zoltan Papp
0eb99c266a Fix unused servers cleanup (#2826)
The cleanup loop did not manage those situations well when a connection failed or 
the connection success but the code did not add a peer connection to it yet.

- in the cleanup loop check if a connection failed to a server
- after adding a foreign server connection force to keep it a minimum 5 sec
2024-11-01 12:33:29 +01:00
Pascal Fischer
bac95ace18 [management] Add DB access duration to logs for context cancel (#2781) 2024-11-01 10:58:39 +01:00
Zoltan Papp
9812de853b Allocate new buffer for every package (#2823) 2024-11-01 00:33:25 +01:00
Zoltan Papp
ad4f0a6fdf [client] Nil check on ICE remote conn (#2806) 2024-10-31 23:18:35 +01:00
Pascal Fischer
4c758c6e52 [management] remove network map diff calculations (#2820) 2024-10-31 19:24:15 +01:00
Misha Bragin
ec5095ba6b Create FUNDING.yml (#2814) 2024-10-30 17:25:02 +01:00
Misha Bragin
49a54624f8 Create funding.json (#2813) 2024-10-30 17:18:27 +01:00
Pascal Fischer
729bcf2b01 [management] add metrics to network map diff (#2811) 2024-10-30 16:53:23 +01:00
Jing
a0cdb58303 [client] Fix the broken dependency gvisor.dev/gvisor (#2789)
The release was removed which is described at
https://github.com/google/gvisor/issues/11085#issuecomment-2438974962.
2024-10-29 20:17:40 +01:00
pascal-fischer
39c99781cb fix meta is equal slices (#2807) 2024-10-29 19:54:38 +01:00
Marco Garcês
01f24907c5 [client] Fix multiple peer name filtering in netbird status command (#2798) 2024-10-29 17:49:41 +01:00
pascal-fischer
10480eb52f [management] Setup key improvements (#2775) 2024-10-28 17:52:23 +01:00
pascal-fischer
1e44c5b574 [client] allow relay leader on iOS (#2795) 2024-10-28 16:55:00 +01:00
Viktor Liu
940f8b4547 [client] Remove legacy forwarding rules in userspace mode (#2782) 2024-10-28 12:29:29 +01:00
Viktor Liu
46e37fa04c [client] Ignore route rules with no sources instead of erroring out (#2786) 2024-10-28 12:28:44 +01:00
Stefano
b9f205b2ce [misc] Update Zitadel from v2.54.10 to v2.64.1 2024-10-28 10:08:17 +01:00
Viktor Liu
0fd874fa45 [client] Make native firewall init fail firewall creation (#2784) 2024-10-28 10:02:27 +01:00
Viktor Liu
8016710d24 [client] Cleanup firewall state on startup (#2768) 2024-10-24 14:46:24 +02:00
Zoltan Papp
4e918e55ba [client] Fix controller re-connection (#2758)
Rethink the peer reconnection implementation
2024-10-24 11:43:14 +02:00
Viktor Liu
869537c951 [client] Cleanup dns and route states on startup (#2757) 2024-10-24 10:53:46 +02:00
Zoltan Papp
44f2ce666e [relay-client] Log exposed address (#2771)
* Log exposed address
2024-10-23 18:32:27 +02:00
pascal-fischer
563dca705c [management] Fix session inactivity response (#2770) 2024-10-23 16:40:15 +02:00
Bethuel Mmbaga
7bda385e1b [management] Optimize network map updates (#2718)
* Skip peer update on unchanged network map (#2236)

* Enhance network updates by skipping unchanged messages

Optimizes the network update process
by skipping updates where no changes in the peer update message received.

* Add unit tests

* add locks

* Improve concurrency and update peer message handling

* Refactor account manager network update tests

* fix test

* Fix inverted network map update condition

* Add default group and policy to test data

* Run peer updates in a separate goroutine

* Refactor

* Refactor lock

* Fix peers update by including NetworkMap and posture Checks

* go mod tidy

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Skip account peers update if no changes affect peers (#2310)

* Remove incrementing network serial and updating peers after group deletion

* Update account peer if posture check is linked to policy

* Remove account peers update on saving setup key

* Refactor group link checking into re-usable functions

* Add HasPeers function to group

* Refactor group management

* Optimize group change effects on account peers

* Update account peers if ns group has peers

* Refactor group changes

* Optimize account peers update in DNS settings

* Optimize update of account peers on jwt groups sync

* Refactor peer account updates for efficiency

* Optimize peer update on user deletion and changes

* Remove condition check for network serial update

* Optimize account peers updates on route changes

* Remove UpdatePeerSSHKey method

* Remove unused isPolicyRuleGroupsEmpty

* Add tests for peer update behavior on posture check changes

* Add tests for peer update behavior on policy changes

* Add tests for peer update behavior on group changes

* Add tests for peer update behavior on dns settings changes

* Refactor

* Add tests for peer update behavior on name server changes

* Add tests for peer update behavior on user changes

* Add tests for peer update behavior on route changes

* fix tests

* Add tests for peer update behavior on setup key changes

* Add tests for peer update behavior on peers changes

* fix merge

* Fix tests

* go mod tidy

* Add NameServer and Route comparators

* Update network map diff logic with custom comparators

* Add tests

* Refactor duplicate diff handling logic

* fix linter

* fix tests

* Refactor policy group handling and update logic.

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update route check by checking if group has peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor posture check policy linking logic

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Simplify peer update condition in DNS management

Refactor the condition for updating account peers to remove redundant checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add policy tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add posture checks tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix user and setup key tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix account and route tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix typo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix nameserver tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix routes tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix group tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* upgrade diff package

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix nameserver tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* use generic differ for netip.Addr and netip.Prefix

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* go mod tidy

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add peer tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix management suite tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix postgres tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* enable diff nil structs comparison

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip the update only last sent the serial is larger

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor peer and user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip spell check for groupD

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor group, ns group, policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip spell check for GroupD

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update account policy check before verifying policy status

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* add tests missing tests for dns setting groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests for posture checks changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add ns group and policy tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add route and group tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* increase Linux test timeout to 10 minutes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run diff for client posture checks only

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add panic recovery and detailed logging in peer update comparison

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-23 13:05:02 +03:00
Zoltan Papp
30ebcf38c7 [client] Eliminate UDP proxy in user-space mode (#2712)
In the case of user space WireGuard mode, use in-memory proxy between the TURN/Relay connection and the WireGuard Bind. We keep the UDP proxy and eBPF proxy for kernel mode.

The key change is the new wgproxy/bind and the iface/bind/ice_bind changes. Everything else is just to fulfill the dependencies.
2024-10-22 20:53:14 +02:00
Bethuel Mmbaga
0106a95f7a lock account and use transaction (#2767)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-22 13:29:17 +03:00
216 changed files with 9667 additions and 3159 deletions

3
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,3 @@
# These are supported funding model platforms
github: [netbirdio]

View File

@@ -13,6 +13,7 @@ concurrency:
jobs: jobs:
test: test:
strategy: strategy:
fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
@@ -49,7 +50,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
@@ -79,9 +80,6 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./client/iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -98,7 +96,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin - name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
@@ -106,7 +104,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif ignore_words_list: erro,clienta,hastable,iif,groupd
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.16" SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"

View File

@@ -19,6 +19,10 @@
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a> </a>
</p> </p>
</div> </div>

View File

@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
} }
if slices.Contains(wellKnown, addr.String()) { if slices.Contains(wellKnown, addr.String()) {

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"sync"
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -13,10 +14,11 @@ import (
) )
type program struct { type program struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
serv *grpc.Server serv *grpc.Server
serverInstance *server.Server serverInstance *server.Server
serverInstanceMu sync.Mutex
} }
func newProgram(ctx context.Context, cancel context.CancelFunc) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
} }
proto.RegisterDaemonServiceServer(p.serv, serverInstance) proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1]) log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil { if err := p.serv.Serve(listen); err != nil {
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
} }
func (p *program) Stop(srv service.Service) error { func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil { if p.serverInstance != nil {
in := new(proto.DownRequest) in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in) _, err := p.serverInstance.Down(p.ctx, in)
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err) log.Errorf("failed to stop daemon: %v", err)
} }
} }
p.serverInstanceMu.Unlock()
p.cancel() p.cancel()

View File

@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := false nameEval := true
if statusFilter != "" { if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter) lowerStatusFilter := strings.ToLower(statusFilter)
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
if len(prefixNamesFilter) > 0 { if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap { for prefixNameFilter := range prefixNamesFilterMap {
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true nameEval = false
break break
} }
} }
} else {
nameEval = false
} }
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval

View File

@@ -3,7 +3,6 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
@@ -11,10 +10,11 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@@ -3,7 +3,7 @@
package firewall package firewall
import ( import (
"context" "errors"
"fmt" "fmt"
"os" "os"
@@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
var fm firewall.Manager fm, err := createNativeFirewall(iface, stateManager)
var errFw error
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Info("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface) return nbiptables.Create(iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Info("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface) return nbnftables.Create(iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default: default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
} }
if iface.IsUserspaceBind() { if errUsp != nil {
var errUsp error return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
} }
if errFw != nil { if err := fm.AllowNetbird(); err != nil {
return nil, errFw log.Errorf("failed to allow netbird interface traffic: %v", err)
} }
return fm, nil return fm, nil
} }

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -22,6 +23,8 @@ const (
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
) )
type aclEntries map[string][][]string
type entry struct { type entry struct {
spec []string spec []string
position int position int
@@ -32,9 +35,11 @@ type aclManager struct {
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string routingFwChainName string
entries map[string][][]string entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
@@ -48,24 +53,30 @@ func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routi
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
} }
err := ipset.Init() if err := ipset.Init(); err != nil {
if err != nil { return nil, fmt.Errorf("init ipset: %w", err)
return nil, fmt.Errorf("failed to init ipset: %w", err)
} }
return m, nil
}
func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries() m.seedInitialEntries()
m.seedInitialOptionalEntries() m.seedInitialOptionalEntries()
err = m.cleanChains() if err := m.cleanChains(); err != nil {
if err != nil { return fmt.Errorf("clean chains: %w", err)
return nil, err
} }
err = m.createDefaultChains() if err := m.createDefaultChains(); err != nil {
if err != nil { return fmt.Errorf("create default chains: %w", err)
return nil, err
} }
return m, nil
m.updateState()
return nil
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddPeerFiltering(
@@ -146,6 +157,8 @@ func (m *aclManager) AddPeerFiltering(
chain: chain, chain: chain,
} }
m.updateState()
return []firewall.Rule{rule}, nil return []firewall.Rule{rule}, nil
} }
@@ -180,15 +193,23 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
} }
} }
err := m.iptablesClient.Delete(tableName, r.chain, r.specs...) if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
if err != nil { return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
} }
return err
m.updateState()
return nil
} }
func (m *aclManager) Reset() error { func (m *aclManager) Reset() error {
return m.cleanChains() if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
m.updateState()
return nil
} }
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
@@ -331,14 +352,14 @@ func (m *aclManager) seedInitialEntries() {
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {
m.optionalEntries["FORWARD"] = []entry{ m.optionalEntries["FORWARD"] = []entry{
{ {
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark), "-j", chainNameInputRules}, spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2, position: 2,
}, },
} }
m.optionalEntries["PREROUTING"] = []entry{ m.optionalEntries["PREROUTING"] = []entry{
{ {
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmark)}, spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1, position: 1,
}, },
} }
@@ -348,6 +369,32 @@ func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec) m.entries[chainName] = append(m.entries[chainName], spec)
} }
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,

View File

@@ -8,10 +8,13 @@ import (
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Manager of iptables firewall // Manager of iptables firewall
@@ -33,10 +36,10 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("init iptables: %w", err)
} }
m := &Manager{ m := &Manager{
@@ -44,20 +47,51 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(context, iptablesClient, wgIface) m.router, err = newRouter(iptablesClient, wgIface)
if err != nil { if err != nil {
log.Debugf("failed to initialize route related chains: %s", err) return nil, fmt.Errorf("create router: %w", err)
return nil, err
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil { if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err) return nil, fmt.Errorf("create acl manager: %w", err)
return nil, err
} }
return m, nil return m, nil
} }
func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering adds a rule to the firewall // AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
@@ -133,20 +167,27 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
errAcl := m.aclMgr.Reset() var merr *multierror.Error
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
errMgr := m.router.Reset() if err := m.router.Reset(); err != nil {
if errMgr != nil { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
} }
return errAcl
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic

View File

@@ -1,7 +1,6 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@@ -56,13 +55,14 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -122,7 +122,7 @@ func TestIptablesManager(t *testing.T) {
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -3,7 +3,6 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net/netip" "net/netip"
"strconv" "strconv"
@@ -18,22 +17,25 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
) "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
const (
ipv4Nat = "netbird-rt-nat"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set" matchSet = "--match-set"
) )
@@ -48,28 +50,31 @@ type routeFilteringRuleParams struct {
SetName string SetName string
} }
type routeRules map[string][]string
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct { type router struct {
ctx context.Context
stop context.CancelFunc
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
rules map[string][]string rules routeRules
ipsetCounter *refcounter.Counter[string, []netip.Prefix, struct{}] ipsetCounter *ipsetCounter
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
stateManager *statemanager.Manager
} }
func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{ r := &router{
ctx: ctx,
stop: cancel,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
r.createIpSet, func(name string, sources []netip.Prefix) (struct{}, error) {
return struct{}{}, r.createIpSet(name, sources)
},
func(name string, _ struct{}) error { func(name string, _ struct{}) error {
return r.deleteIpSet(name) return r.deleteIpSet(name)
}, },
@@ -79,16 +84,23 @@ func newRouter(parentCtx context.Context, iptablesClient *iptables.IPTables, wgI
return nil, fmt.Errorf("init ipset: %w", err) return nil, fmt.Errorf("init ipset: %w", err)
} }
err := r.cleanUpDefaultForwardRules() return r, nil
if err != nil { }
log.Errorf("cleanup routing rules: %s", err)
return nil, err func (r *router) init(stateManager *statemanager.Manager) error {
r.stateManager = stateManager
if err := r.cleanUpDefaultForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
err = r.createContainers()
if err != nil { if err := r.createContainers(); err != nil {
log.Errorf("create containers for route: %s", err) return fmt.Errorf("create containers: %w", err)
} }
return r, err
r.updateState()
return nil
} }
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
@@ -129,6 +141,8 @@ func (r *router) AddRouteFiltering(
r.rules[string(ruleKey)] = rule r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil return ruleKey, nil
} }
@@ -152,6 +166,8 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
} }
r.updateState()
return nil return nil
} }
@@ -164,18 +180,18 @@ func (r *router) findSetNameInRule(rule []string) string {
return "" return ""
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) (struct{}, error) { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil { if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return struct{}{}, fmt.Errorf("create set %s: %w", setName, err) return fmt.Errorf("create set %s: %w", setName, err)
} }
for _, prefix := range sources { for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil { if err := ipset.AddPrefix(setName, prefix); err != nil {
return struct{}{}, fmt.Errorf("add element to set %s: %w", setName, err) return fmt.Errorf("add element to set %s: %w", setName, err)
} }
} }
return struct{}{}, nil return nil
} }
func (r *router) deleteIpSet(setName string) error { func (r *router) deleteIpSet(setName string) error {
@@ -206,6 +222,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("add inverse nat rule: %w", err) return fmt.Errorf("add inverse nat rule: %w", err)
} }
r.updateState()
return nil return nil
} }
@@ -223,6 +241,8 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy routing rule: %w", err) return fmt.Errorf("remove legacy routing rule: %w", err)
} }
r.updateState()
return nil return nil
} }
@@ -278,8 +298,13 @@ func (r *router) RemoveAllLegacyRouteRules() error {
} }
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
} }
} }
r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@@ -294,28 +319,31 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) cleanUpDefaultForwardRules() error { func (r *router) cleanUpDefaultForwardRules() error {
err := r.cleanJumpRules() if err := r.cleanJumpRules(); err != nil {
if err != nil { return fmt.Errorf("clean jump rules: %w", err)
return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chainInfo := range []struct {
table := r.getTableForChain(chain) chain string
table string
ok, err := r.iptablesClient.ChainExists(table, chain) }{
{chainRTFWD, tableFilter},
{chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
log.Errorf("failed check chain %s, error: %v", chain, err) return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
return err
} else if ok { } else if ok {
err = r.iptablesClient.ClearAndDeleteChain(table, chain) if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
if err != nil { return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
log.Errorf("failed cleaning chain %s, error: %v", chain, err)
return err
} }
} }
} }
@@ -324,9 +352,16 @@ func (r *router) cleanUpDefaultForwardRules() error {
} }
func (r *router) createContainers() error { func (r *router) createContainers() error {
for _, chain := range []string{chainRTFWD, chainRTNAT} { for _, chainInfo := range []struct {
if err := r.createAndSetupChain(chain); err != nil { chain string
return fmt.Errorf("create chain %s: %w", chain, err) table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
} }
@@ -334,6 +369,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("insert established rule: %w", err) return fmt.Errorf("insert established rule: %w", err)
} }
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil { if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err) return fmt.Errorf("add jump rules: %w", err)
} }
@@ -341,6 +380,32 @@ func (r *router) createContainers() error {
return nil return nil
} }
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error { func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain) table := r.getTableForChain(chain)
@@ -352,10 +417,14 @@ func (r *router) createAndSetupChain(chain string) error {
} }
func (r *router) getTableForChain(chain string) string { func (r *router) getTableForChain(chain string) string {
if chain == chainRTNAT { switch chain {
case chainRTNAT:
return tableNat return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
} }
return tableFilter
} }
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
@@ -373,25 +442,39 @@ func (r *router) insertEstablishedRule(chain string) error {
} }
func (r *router) addJumpRules() error { func (r *router) addJumpRules() error {
rule := []string{"-j", chainRTNAT} // Jump to NAT chain
err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...) natRule := []string{"-j", chainRTNAT}
if err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return err return fmt.Errorf("add nat jump rule: %v", err)
} }
r.rules[ipv4Nat] = rule r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil return nil
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
rule, found := r.rules[ipv4Nat] for _, ruleKey := range []string{jumpNat, jumpPre} {
if found { if rule, exists := r.rules[ruleKey]; exists {
err := r.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...) table := tableNat
if err != nil { chain := chainPOSTROUTING
return fmt.Errorf("failed cleaning rule from chain %s, err: %v", chainPOSTROUTING, err) if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
} }
} }
return nil return nil
} }
@@ -399,19 +482,35 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} }
rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, r.wgIface.Name(), pair.Inverse) markValue := nbnet.PreroutingFwmarkMasquerade
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule...); err != nil { if pair.Inverse {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err) markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
r.rules[ruleKey] = rule r.rules[ruleKey] = rule
return nil return nil
} }
@@ -419,26 +518,41 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableNat, chainRTNAT, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing nat rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else { } else {
log.Debugf("nat rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
return nil return nil
} }
func genRuleSpec(jump string, source, destination netip.Prefix, intf string, inverse bool) []string { func (r *router) updateState() {
intdir := "-i" if r.stateManager == nil {
lointdir := "-o" return
if inverse { }
intdir = "-o"
lointdir = "-i" var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
} }
return []string{intdir, intf, "!", lointdir, "lo", "-s", source.String(), "-d", destination.String(), "-j", jump}
} }
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {

View File

@@ -3,18 +3,18 @@
package iptables package iptables
import ( import (
"context" "fmt"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {
@@ -30,18 +30,29 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.Len(t, manager.rules, 2, "should have created rules map") // Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
@@ -49,22 +60,15 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true, Masquerade: true,
} }
forward4Rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.AddNatRule(pair)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "adding NAT rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination, ifaceMock.Name(), false)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_AddNatRule(t *testing.T) { func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -74,56 +78,71 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
err := manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}() }()
err = manager.AddNatRule(testCase.InputPair) err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "marking rule should be inserted")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false) markingRule := []string{
"-i", ifaceMock.Name(),
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) "-m", "conntrack",
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) "--ctstate", "NEW",
if testCase.InputPair.Masquerade { "-s", testCase.InputPair.Source.String(),
require.True(t, exists, "nat rule should be created") "-d", testCase.InputPair.Destination.String(),
foundNatRule, foundNat := manager.rules[natRuleKey] "-j", "MARK", "--set-mark",
require.True(t, foundNat, "nat rule should exist in the map") fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
} }
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created") require.True(t, exists, "marking rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey] foundRule, found := manager.rules[natRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map") require.True(t, found, "marking rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") require.Equal(t, markingRule, foundRule, "stored marking rule should match")
} else { } else {
require.False(t, exists, "nat rule should not be created") require.False(t, exists, "marking rule should not be created")
_, foundNat := manager.rules[inNatRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map") require.False(t, found, "marking rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
} }
}) })
} }
} }
func TestIptablesManager_RemoveNatRule(t *testing.T) { func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -132,45 +151,56 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(context.TODO(), iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.NoError(t, err, "shouldn't return error") err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination, ifaceMock.Name(), false)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair))
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInversePair(testCase.InputPair).Source, firewall.GetInversePair(testCase.InputPair).Destination, ifaceMock.Name(), true)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair) err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableNat, chainRTNAT, natRule...) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) markingRule := []string{
require.False(t, exists, "nat rule should not exist") "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
_, found := manager.rules[natRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map") require.False(t, found, "marking rule should not exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...) // Check inverse rule removal
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT) inversePair := firewall.GetInversePair(testCase.InputPair)
require.False(t, exists, "income nat rule should not exist") inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
_, found = manager.rules[inNatRuleKey] exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.False(t, found, "income nat rule should exist in the manager map") require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
}) })
} }
} }
@@ -183,8 +213,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client") require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(context.Background(), iptablesClient, ifaceMock) r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager") require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() { defer func() {
err := r.Reset() err := r.Reset()

View File

@@ -1,14 +1,16 @@
package iptables package iptables
import "encoding/json"
type ipList struct { type ipList struct {
ips map[string]struct{} ips map[string]struct{}
} }
func newIpList(ip string) ipList { func newIpList(ip string) *ipList {
ips := make(map[string]struct{}) ips := make(map[string]struct{})
ips[ip] = struct{}{} ips[ip] = struct{}{}
return ipList{ return &ipList{
ips: ips, ips: ips,
} }
} }
@@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{} s.ips[ip] = struct{}{}
} }
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
return nil
}
type ipsetStore struct { type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset ipsets map[string]*ipList
} }
func newIpsetStore() *ipsetStore { func newIpsetStore() *ipsetStore {
return &ipsetStore{ return &ipsetStore{
ipsets: make(map[string]ipList), ipsets: make(map[string]*ipList),
} }
} }
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName] r, ok := s.ipsets[ipsetName]
return r, ok return r, ok
} }
func (s *ipsetStore) addIpList(ipsetName string, list ipList) { func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list s.ipsets[ipsetName] = list
} }
func (s *ipsetStore) deleteIpset(ipsetName string) { func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName) delete(s.ipsets, ipsetName)
} }
@@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
} }
return names return names
} }
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
return nil
}

View File

@@ -0,0 +1,70 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@@ -10,11 +10,14 @@ import (
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
ForwardingFormatPrefix = "netbird-fwd-" ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s-%t" ForwardingFormat = "netbird-fwd-%s-%t"
PreroutingFormat = "netbird-prerouting-%s-%t"
NatFormat = "netbird-nat-%s-%t" NatFormat = "netbird-nat-%s-%t"
) )
@@ -52,6 +55,8 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
AllowNetbird() error AllowNetbird() error
@@ -91,7 +96,7 @@ type Manager interface {
SetLegacyManagement(legacy bool) error SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error

View File

@@ -17,7 +17,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
@@ -56,13 +55,6 @@ type AclManager struct {
rules map[string]*Rule rules map[string]*Rule
} }
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) { func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
@@ -70,10 +62,10 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
// overloads netlink with high amount of rules ( > 10000) // overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting()) sConn, err := nftables.New(nftables.AsLasting())
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create nf conn: %w", err)
} }
m := &AclManager{ return &AclManager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
@@ -82,14 +74,12 @@ func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainNam
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
} }, nil
}
err = m.createDefaultChains() func (m *AclManager) init(workTable *nftables.Table) error {
if err != nil { m.workTable = workTable
return nil, err return m.createDefaultChains()
}
return m, nil
} }
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
@@ -530,7 +520,7 @@ func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
}, },
&expr.Immediate{ &expr.Immediate{
Register: 1, Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyMARK, Key: expr.MetaKeyMARK,
@@ -553,7 +543,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmark), Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,

View File

@@ -14,6 +14,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -24,6 +26,13 @@ const (
chainNameInput = "INPUT" chainNameInput = "INPUT"
) )
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@@ -35,30 +44,70 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
} }
workTable, err := m.createWorkTable() workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
if err != nil {
return nil, err
}
m.router, err = newRouter(context, workTable, wgIface) var err error
m.router, err = newRouter(workTable, wgIface)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create router: %w", err)
} }
m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw) m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create acl manager: %w", err)
} }
return m, nil return m, nil
} }
// Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering rule to the firewall // AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
@@ -150,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameForward { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c chain = c
break break
} }
@@ -183,68 +232,84 @@ func (m *Manager) AllowNetbird() error {
// SetLegacyManagement sets the route manager to use legacy management // SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
oldLegacy := m.router.legacyManagement return firewall.SetLegacyManagement(m.router, isLegacy)
}
if oldLegacy != isLegacy { // Reset firewall to the default state
m.router.legacyManagement = isLegacy func (m *Manager) Reset(stateManager *statemanager.Manager) error {
log.Debugf("Set legacy management to %v", isLegacy) m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
} }
// client reconnected to a newer mgmt, we need to cleanup the legacy rules if err := m.router.Reset(); err != nil {
if !isLegacy && oldLegacy { return fmt.Errorf("reset router: %v", err)
if err := m.router.RemoveAllLegacyRouteRules(); err != nil { }
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed") if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
} }
return nil return nil
} }
// Reset firewall to the default state func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) Reset() error {
m.mutex.Lock()
defer m.mutex.Unlock()
chains, err := m.rConn.ListChains() chains, err := m.rConn.ListChains()
if err != nil { if err != nil {
return fmt.Errorf("list of chains: %w", err) return fmt.Errorf("list chains: %w", err)
} }
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
continue continue
} }
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { m.deleteMatchingRules(rules)
if err := m.rConn.DelRule(r); err != nil { }
log.Errorf("delete rule: %v", err) }
} }
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
} }
} }
} }
}
if err := m.router.Reset(); err != nil { func (m *Manager) cleanupNetbirdTables() error {
return fmt.Errorf("reset forward rules: %v", err)
}
tables, err := m.rConn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableNameNetbird { if t.Name == tableNameNetbird {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return nil
return m.rConn.Flush()
} }
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
@@ -286,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Verdict{}, &expr.Verdict{
Kind: expr.VerdictAccept,
},
}, },
UserData: []byte(allowNetbirdInputRuleID), UserData: []byte(allowNetbirdInputRuleID),
} }

View File

@@ -1,7 +1,6 @@
package nftables package nftables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -58,12 +57,13 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -169,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion") require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -192,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -2,7 +2,6 @@ package nftables
import ( import (
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@@ -22,6 +21,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -40,8 +40,6 @@ var (
) )
type router struct { type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn conn *nftables.Conn
workTable *nftables.Table workTable *nftables.Table
filterTable *nftables.Table filterTable *nftables.Table
@@ -54,12 +52,8 @@ type router struct {
legacyManagement bool legacyManagement bool
} }
func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{ r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
@@ -78,20 +72,25 @@ func newRouter(parentCtx context.Context, workTable *nftables.Table, wgIface iFa
if errors.Is(err, errFilterTableNotFound) { if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules") log.Warnf("table 'filter' not found for forward rules")
} else { } else {
return nil, err return nil, fmt.Errorf("load filter table: %w", err)
} }
} }
err = r.removeAcceptForwardRules() return r, nil
if err != nil { }
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err) log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
err = r.createContainers() if err := r.createContainers(); err != nil {
if err != nil { return fmt.Errorf("create containers: %w", err)
log.Errorf("failed to create containers for route: %s", err)
} }
return r, err
return nil
} }
// Reset cleans existing nftables default forward rules from the system // Reset cleans existing nftables default forward rules from the system
@@ -126,7 +125,6 @@ func (r *router) createContainers() error {
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw]) insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1 prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{ r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat, Name: chainNameRoutingNat,
Table: r.workTable, Table: r.workTable,
@@ -135,6 +133,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil { if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
@@ -424,59 +437,149 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) destExp := generateCIDRMatcherExpressions(false, pair.Destination)
dir := expr.MetaKeyIIFNAME op := expr.CmpOpEq
notDir := expr.MetaKeyOIFNAME
if pair.Inverse { if pair.Inverse {
dir = expr.MetaKeyOIFNAME op = expr.CmpOpNeq
notDir = expr.MetaKeyIIFNAME
} }
lo := ifname("lo")
intf := ifname(r.wgIface.Name())
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Meta{ // We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
Key: dir, // Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1, Register: 1,
}, },
&expr.Cmp{ &expr.Bitwise{
Op: expr.CmpOpEq, SourceRegister: 1,
Register: 1, DestRegister: 1,
Data: intf, Len: 4,
}, Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: notDir,
Register: 1,
}, },
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpNeq, Op: expr.CmpOpNeq,
Register: 1, Register: 1,
Data: lo, Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
}, },
} }
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs, exprs = append(exprs,
&expr.Counter{}, &expr.Masq{}, &expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
) )
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists { if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove routing rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingNat], Chain: r.chains[chainNamePrerouting],
Exprs: exprs, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil return nil
} }
@@ -553,7 +656,10 @@ func (r *router) RemoveAllLegacyRouteRules() error {
} }
if err := r.conn.DelRule(rule); err != nil { if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
} }
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
@@ -722,18 +828,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// RemoveNatRule removes a nftables rule pair from nat chains // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
} }
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
@@ -748,21 +854,20 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return nil return nil
} }
// removeNatRule adds a nftables rule to the removal queue and deletes it from the rules map
func (r *router) removeNatRule(pair firewall.RouterPair) error { func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule) err := r.conn.DelRule(rule)
if err != nil { if err != nil {
return fmt.Errorf("remove nat rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed nat rule %s -> %s", pair.Source, pair.Destination) log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else { } else {
log.Debugf("nftables: nat rule %s not found", ruleKey) log.Debugf("nftables: prerouting rule %s not found", ruleKey)
} }
return nil return nil

View File

@@ -3,7 +3,6 @@
package nftables package nftables
import ( import (
"context"
"encoding/binary" "encoding/binary"
"net/netip" "net/netip"
"os/exec" "os/exec"
@@ -11,6 +10,7 @@ import (
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -33,99 +33,87 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock) // need fw manager to init both acl mgr and router for all chains to be present
require.NoError(t, err, "failed to create router") manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) { rtr := manager.router
require.NoError(t, manager.Reset(), "failed to reset rules") err = rtr.AddNatRule(testCase.InputPair)
}(manager)
require.NoError(t, err, "shouldn't return error")
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted") require.NoError(t, err, "pair should be inserted")
defer func(manager *router, pair firewall.RouterPair) { t.Cleanup(func() {
require.NoError(t, manager.RemoveNatRule(pair), "failed to remove rule") require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
}(manager, testCase.InputPair) })
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) // Build expected expressions for connection tracking
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) conntrackExprs := []expr.Any{
testingExpression := append(sourceExp, destExp...) //nolint:gocritic &expr.Ct{
testingExpression = append(testingExpression, Key: expr.CtKeySTATE,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: ifname(ifaceMock.Name()), Data: ifname(ifaceMock.Name()),
}, },
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match")
found = 1
}
}
} }
require.Equal(t, 1, found, "should find at least 1 rule to test")
}
if testCase.InputPair.Masquerade { // Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
testingExpression = append(testingExpression,
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
)
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) // Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range rtr.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) if chain.Name == chainNamePrerouting {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
for _, rule := range rules { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey { for _, rule := range rules {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match") if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = 1 // Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
} }
} }
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
} }
}) })
} }
} }
@@ -135,67 +123,66 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table, ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err, "failed to create router") t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
nftablesTestingClient := &nftables.Conn{}
defer func(manager *router) {
require.NoError(t, manager.Reset(), "failed to reset rules")
}(manager)
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRoutingNat],
Exprs: natExp,
UserData: []byte(natRuleKey),
}) })
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInversePair(testCase.InputPair).Source) rtr := manager.router
destExp = generateCIDRMatcherExpressions(false, firewall.GetInversePair(testCase.InputPair).Destination)
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic // First add the NAT rule using the router's method
inNatRuleKey := firewall.GenKey(firewall.NatFormat, firewall.GetInversePair(testCase.InputPair)) err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ // Verify the rule was added
Table: manager.workTable, natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
Chain: manager.chains[chainNameRoutingNat], found := false
Exprs: natExp, rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
UserData: []byte(inNatRuleKey), require.NoError(t, err, "should list rules")
}) for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
err = nftablesTestingClient.Flush() found = true
require.NoError(t, err, "shouldn't return error") break
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 {
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist")
}
} }
} }
require.True(t, found, "NAT rule should exist before removal")
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
}) })
} }
} }
@@ -210,8 +197,9 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func(r *router) { defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules") require.NoError(t, r.Reset(), "Failed to reset rules")
@@ -376,8 +364,9 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(context.Background(), workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
defer func() { defer func() {
require.NoError(t, r.Reset(), "Failed to reset router") require.NoError(t, r.Reset(), "Failed to reset router")

View File

@@ -0,0 +1 @@
package nftables

View File

@@ -0,0 +1,47 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@@ -2,8 +2,10 @@
package uspfilter package uspfilter
import "github.com/netbirdio/netbird/client/internal/statemanager"
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset() return m.nativeFirewall.Reset(stateManager)
} }
return nil return nil
} }

View File

@@ -6,6 +6,8 @@ import (
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type action string type action string
@@ -17,7 +19,7 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
@@ -97,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return false return false
@@ -190,7 +195,7 @@ func (m *Manager) AddPeerFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering(sources [] netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action ) (firewall.Rule, error) { func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return nil, errRouteNotSupported return nil, errRouteNotSupported
} }
@@ -232,8 +237,11 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
} }
// SetLegacyManagement doesn't need to be implemented for this manager // SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(_ bool) error { func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return nil if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
} }
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager

View File

@@ -259,7 +259,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
err = m.Reset() err = m.Reset(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -330,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if err = m.Reset(); err != nil { if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -396,7 +396,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -1,142 +0,0 @@
package bind
import (
"fmt"
"net"
"runtime"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(msgPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(msgPool, pc, conn)
}
type ICEBind struct {
*wgConn.StdNetBind
muUDPMux sync.Mutex
transportNet transport.Net
udpMux *UniversalUDPMuxDefault
filterFn FilterFn
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{
transportNet: transportNet,
filterFn: filterFn,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketConn, conn *net.UDPConn) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := ipv4MsgsPool.Get().(*[]ipv4.Message)
defer ipv4MsgsPool.Put(msgs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
}
var numMsgs int
if runtime.GOOS == "linux" {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
sizes[i] = 0
} else {
sizes[i] = msg.N
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}

View File

@@ -0,0 +1,5 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint

View File

@@ -0,0 +1,303 @@
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
// 1. filter out STUN messages and handle them
// 2. forward the received packets to the WireGuard interface from the relayed connection
//
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}
func (s *ICEBind) Close() error {
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
return s.StdNetBind.Close()
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}
for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -5,7 +5,6 @@ package device
import ( import (
"strings" "strings"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
@@ -31,13 +30,13 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) *WGTunDevice { func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
} }
} }

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"os/exec" "os/exec"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -29,14 +28,14 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }

View File

@@ -6,7 +6,6 @@ package device
import ( import (
"os" "os"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
@@ -30,13 +29,13 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *TunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
tunFd: tunFd, tunFd: tunFd,
} }
} }

View File

@@ -6,7 +6,6 @@ package device
import ( import (
"fmt" "fmt"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
@@ -31,7 +30,7 @@ type TunNetstackDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) *TunNetstackDevice { func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -39,7 +38,7 @@ func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, m
key: key, key: key,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }

View File

@@ -7,7 +7,6 @@ import (
"os" "os"
"runtime" "runtime"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -30,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *USPDevice { func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
checkUser() checkUser()
@@ -41,7 +40,8 @@ func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn)} iceBind: iceBind,
}
} }
func (t *USPDevice) Create() (WGConfigurer, error) { func (t *USPDevice) Create() (WGConfigurer, error) {

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
@@ -32,14 +31,14 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) *TunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }

View File

@@ -6,12 +6,16 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
) )
const ( const (
@@ -22,14 +26,35 @@ const (
type WGAddress = device.WGAddress type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
}
type WGIFaceOpts struct {
IFaceName string
Address string
WGPort int
WGPrivKey string
MTU int
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
}
// WGIface represents an interface instance // WGIface represents an interface instance
type WGIface struct { type WGIface struct {
tun WGTunDevice tun WGTunDevice
userspaceBind bool userspaceBind bool
mu sync.Mutex mu sync.Mutex
configurer device.WGConfigurer configurer device.WGConfigurer
filter device.PacketFilter filter device.PacketFilter
wgProxyFactory wgProxyFactory
}
func (w *WGIface) GetProxy() wgproxy.Proxy {
return w.wgProxyFactory.GetProxy()
} }
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@@ -124,22 +149,26 @@ func (w *WGIface) Close() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
err := w.tun.Close() var result *multierror.Error
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err) if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }
err = w.waitUntilRemoved() if err := w.tun.Close(); err != nil {
if err != nil { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy() if err := w.Destroy(); err != nil {
if err != nil { result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err) return errors.FormatErrorOrNil(result)
} }
log.Infof("interface %s successfully removed", w.Name()) log.Infof("interface %s successfully removed", w.Name())
} }
return nil return errors.FormatErrorOrNil(result)
} }
// SetFilter sets packet filters for the userspace implementation // SetFilter sets packet filters for the userspace implementation

View File

@@ -1,43 +0,0 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -2,6 +2,8 @@
package iface package iface
import "fmt"
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
// this function is different on Android // this function is different on Android
@@ -17,3 +19,8 @@ func (w *WGIface) Create() error {
w.configurer = cfgr w.configurer = cfgr
return nil return nil
} }
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}

View File

@@ -0,0 +1,24 @@
package iface
import (
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -7,39 +7,8 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
) )
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
// this function is different on Android // this function is different on Android
@@ -65,3 +34,8 @@ func (w *WGIface) Create() error {
return backoff.Retry(operation, backOff) return backoff.Retry(operation, backOff)
} }
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -0,0 +1,10 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@@ -1,31 +0,0 @@
//go:build ios
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
tun: device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true,
}
return wgIFace, nil
}
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
) )
type MockWGIface struct { type MockWGIface struct {
@@ -30,6 +31,7 @@ type MockWGIface struct {
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
} }
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
@@ -103,3 +105,8 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice {
func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey) return m.GetStatsFunc(peerKey)
} }
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,24 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,34 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,26 @@
//go:build ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
userspaceBind: true,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,45 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, opts.TransportNet)
wgIFace.wgProxyFactory = wgproxy.NewKernelFactory(opts.WGPort)
return wgIFace, nil
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
return wgIFace, nil
}
return nil, fmt.Errorf("couldn't check or load tun module")
}

View File

@@ -0,0 +1,32 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -45,7 +45,16 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: addr,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -118,7 +127,16 @@ func Test_CreateInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -153,7 +171,16 @@ func Test_Close(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -189,7 +216,16 @@ func TestRecreation(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -252,7 +288,15 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: wgPort,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -300,7 +344,16 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -361,7 +414,16 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil) opts := WGIFaceOpts{
IFaceName: ifaceName,
Address: wgIP,
WGPort: 33100,
WGPrivKey: key,
MTU: DefaultMTU,
TransportNet: newNet,
}
iface, err := NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -418,7 +480,15 @@ func Test_ConnectPeers(t *testing.T) {
guid := fmt.Sprintf("{%s}", uuid.New().String()) guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid) device.CustomWindowsGUIDString = strings.ToLower(guid)
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil) optsPeer1 := WGIFaceOpts{
IFaceName: peer1ifaceName,
Address: peer1wgIP,
WGPort: peer1wgPort,
WGPrivKey: peer1Key.String(),
MTU: DefaultMTU,
TransportNet: newNet,
}
iface1, err := NewWGIFace(optsPeer1)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -432,7 +502,12 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer1wgPort)) localIP, err := getLocalIP()
if err != nil {
t.Fatal(err)
}
peer1endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer1wgPort))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -444,7 +519,17 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
optsPeer2 := WGIFaceOpts{
IFaceName: peer2ifaceName,
Address: peer2wgIP,
WGPort: peer2wgPort,
WGPrivKey: peer2Key.String(),
MTU: DefaultMTU,
TransportNet: newNet,
}
iface2, err := NewWGIFace(optsPeer2)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -458,7 +543,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("127.0.0.1:%d", peer2wgPort)) peer2endpoint, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", localIP, peer2wgPort))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -527,3 +612,28 @@ func getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
} }
return wgtypes.Peer{}, fmt.Errorf("peer not found") return wgtypes.Peer{}, fmt.Errorf("peer not found")
} }
func getLocalIP() (string, error) {
// Get all interfaces
addrs, err := net.InterfaceAddrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
if ipNet.IP.To4() == nil {
continue
}
return ipNet.IP.String(), nil
}
return "", fmt.Errorf("no local IP found")
}

View File

@@ -1,49 +0,0 @@
//go:build (linux && !android) || freebsd
package iface
import (
"fmt"
"runtime"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{}
// move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true
return wgIFace, nil
}
if device.WireGuardModuleIsLoaded() {
wgIFace.tun = device.NewKernelDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet)
wgIFace.userspaceBind = false
return wgIFace, nil
}
if !device.ModuleTunIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module")
}
wgIFace.tun = device.NewUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("CreateOnAndroid function has not implemented on %s platform", runtime.GOOS)
}

View File

@@ -1,41 +0,0 @@
package iface
import (
"fmt"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *device.MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = device.NewNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = device.NewTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
) )
type IWGIface interface { type IWGIface interface {
@@ -22,6 +23,7 @@ type IWGIface interface {
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error

View File

@@ -9,6 +9,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
) )
type IWGIface interface { type IWGIface interface {
@@ -20,6 +21,7 @@ type IWGIface interface {
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error
GetProxy() wgproxy.Proxy
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error

View File

@@ -0,0 +1,141 @@
package bind
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
)
type ProxyBind struct {
Bind *bind.ICEBind
wgAddr *net.UDPAddr
wgEndpoint *bind.Endpoint
remoteConn net.Conn
ctx context.Context
cancel context.CancelFunc
closeMu sync.Mutex
closed bool
pausedMu sync.Mutex
paused bool
isStarted bool
}
// AddTurnConn adds a new connection to the bind.
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn)
if err != nil {
return err
}
p.wgAddr = addr
p.wgEndpoint = addrToEndpoint(addr)
p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx)
return err
}
func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr
}
func (p *ProxyBind) Work() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = false
p.pausedMu.Unlock()
// Start the proxy only once
if !p.isStarted {
p.isStarted = true
go p.proxyToLocal(p.ctx)
}
}
func (p *ProxyBind) Pause() {
if p.remoteConn == nil {
return
}
p.pausedMu.Lock()
p.paused = true
p.pausedMu.Unlock()
}
func (p *ProxyBind) CloseConn() error {
if p.cancel == nil {
return fmt.Errorf("proxy not started")
}
return p.close()
}
func (p *ProxyBind) close() error {
p.closeMu.Lock()
defer p.closeMu.Unlock()
if p.closed {
return nil
}
p.closed = true
p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr)
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr
}
return nil
}
func (p *ProxyBind) proxyToLocal(ctx context.Context) {
defer func() {
if err := p.close(); err != nil {
log.Warnf("failed to close remote conn: %s", err)
}
}()
for {
buf := make([]byte, 1500)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return
}
p.pausedMu.Lock()
if p.paused {
p.pausedMu.Unlock()
continue
}
msg := bind.RecvMessage{
Endpoint: p.wgEndpoint,
Buffer: buf[:n],
}
p.Bind.RecvChan <- msg
p.pausedMu.Unlock()
}
}
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint {
ip, _ := netip.AddrFromSlice(addr.IP.To4())
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port))
return &bind.Endpoint{AddrPort: addrPort}
}

View File

@@ -5,9 +5,9 @@ import (
"net" "net"
) )
const ( var (
portRangeStart = 3128 portRangeStart = 3128
portRangeEnd = 3228 portRangeEnd = portRangeStart + 100
) )
type portLookup struct { type portLookup struct {

View File

@@ -17,6 +17,9 @@ func Test_portLookup_searchFreePort(t *testing.T) {
func Test_portLookup_on_allocated(t *testing.T) { func Test_portLookup_on_allocated(t *testing.T) {
pl := portLookup{} pl := portLookup{}
portRangeStart = 4128
portRangeEnd = portRangeStart + 100
allocatedPort, err := allocatePort(portRangeStart) allocatedPort, err := allocatePort(portRangeStart)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -119,7 +119,7 @@ func (p *WGEBPFProxy) Free() error {
p.ctxCancel() p.ctxCancel()
var result *multierror.Error var result *multierror.Error
if p.conn != nil { // p.conn will be nil if we have failed to listen if p.conn != nil {
if err := p.conn.Close(); err != nil { if err := p.conn.Close(); err != nil {
result = multierror.Append(result, err) result = multierror.Append(result, err)
} }

View File

@@ -28,7 +28,7 @@ type ProxyWrapper struct {
isStarted bool isStarted bool
} }
func (p *ProxyWrapper) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn)
if err != nil { if err != nil {
return fmt.Errorf("add turn conn: %w", err) return fmt.Errorf("add turn conn: %w", err)
@@ -77,7 +77,7 @@ func (e *ProxyWrapper) CloseConn() error {
e.cancel() e.cancel()
if err := e.remoteConn.Close(); err != nil { if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
return fmt.Errorf("failed to close remote conn: %w", err) return fmt.Errorf("failed to close remote conn: %w", err)
} }
return nil return nil

View File

@@ -0,0 +1,49 @@
//go:build linux && !android
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
type KernelFactory struct {
wgPort int
ebpfProxy *ebpf.WGEBPFProxy
}
func NewKernelFactory(wgPort int) *KernelFactory {
f := &KernelFactory{
wgPort: wgPort,
}
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort)
if err := ebpfProxy.Listen(); err != nil {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
log.Warnf("failed to initialize ebpf proxy, fallback to user space proxy: %s", err)
return f
}
log.Infof("WireGuard Proxy Factory will produce eBPF proxy")
f.ebpfProxy = ebpfProxy
return f
}
func (w *KernelFactory) GetProxy() Proxy {
if w.ebpfProxy == nil {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
return &ebpf.ProxyWrapper{
WgeBPFProxy: w.ebpfProxy,
}
}
func (w *KernelFactory) Free() error {
if w.ebpfProxy == nil {
return nil
}
return w.ebpfProxy.Free()
}

View File

@@ -0,0 +1,29 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
)
// KernelFactory todo: check eBPF support on FreeBSD
type KernelFactory struct {
wgPort int
}
func NewKernelFactory(wgPort int) *KernelFactory {
log.Infof("WireGuard Proxy Factory will produce UDP proxy")
f := &KernelFactory{
wgPort: wgPort,
}
return f
}
func (w *KernelFactory) GetProxy() Proxy {
return udpProxy.NewWGUDPProxy(w.wgPort)
}
func (w *KernelFactory) Free() error {
return nil
}

View File

@@ -0,0 +1,30 @@
package wgproxy
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bind"
proxyBind "github.com/netbirdio/netbird/client/iface/wgproxy/bind"
)
type USPFactory struct {
bind *bind.ICEBind
}
func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory {
log.Infof("WireGuard Proxy Factory will produce bind proxy")
f := &USPFactory{
bind: iceBind,
}
return f
}
func (w *USPFactory) GetProxy() Proxy {
return &proxyBind.ProxyBind{
Bind: w.bind,
}
}
func (w *USPFactory) Free() error {
return nil
}

View File

@@ -0,0 +1,15 @@
package wgproxy
import (
"context"
"net"
)
// Proxy is a transfer layer between the relayed connection and the WireGuard
type Proxy interface {
AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error
EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint
Work() // Work start or resume the proxy
Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works.
CloseConn() error
}

View File

@@ -0,0 +1,56 @@
//go:build linux && !android
package wgproxy
import (
"context"
"os"
"testing"
"github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
)
func TestProxyCloseByRemoteConnEBPF(t *testing.T) {
if os.Getenv("GITHUB_ACTIONS") != "true" {
t.Skip("Skipping test as it requires root privileges")
}
ctx := context.Background()
ebpfProxy := ebpf.NewWGEBPFProxy(51831)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %s", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %s", err)
}
}()
tests := []struct {
name string
proxy Proxy
}{
{
name: "ebpf proxy",
proxy: &ebpf.ProxyWrapper{
WgeBPFProxy: ebpfProxy,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil {
t.Errorf("error: %v", err)
}
_ = relayedConn.Close()
if err := tt.proxy.CloseConn(); err != nil {
t.Errorf("error: %v", err)
}
})
}
}

View File

@@ -11,8 +11,8 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/wgproxy/ebpf" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf"
"github.com/netbirdio/netbird/client/internal/wgproxy/usp" udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -84,7 +84,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
}{ }{
{ {
name: "userspace proxy", name: "userspace proxy",
proxy: usp.NewWGUserSpaceProxy(51830), proxy: udpProxy.NewWGUDPProxy(51830),
}, },
} }
@@ -114,7 +114,7 @@ func TestProxyCloseByRemoteConn(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
relayedConn := newMockConn() relayedConn := newMockConn()
err := tt.proxy.AddTurnConn(ctx, relayedConn) err := tt.proxy.AddTurnConn(ctx, nil, relayedConn)
if err != nil { if err != nil {
t.Errorf("error: %v", err) t.Errorf("error: %v", err)
} }

View File

@@ -1,19 +1,21 @@
package usp package udp
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"io"
"net" "net"
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors" cerrors "github.com/netbirdio/netbird/client/errors"
) )
// WGUserSpaceProxy proxies // WGUDPProxy proxies
type WGUserSpaceProxy struct { type WGUDPProxy struct {
localWGListenPort int localWGListenPort int
remoteConn net.Conn remoteConn net.Conn
@@ -28,10 +30,10 @@ type WGUserSpaceProxy struct {
isStarted bool isStarted bool
} }
// NewWGUserSpaceProxy instantiate a user space WireGuard proxy. This is not a thread safe implementation // NewWGUDPProxy instantiate a UDP based WireGuard proxy. This is not a thread safe implementation
func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy { func NewWGUDPProxy(wgPort int) *WGUDPProxy {
log.Debugf("Initializing new user space proxy with port %d", wgPort) log.Debugf("Initializing new user space proxy with port %d", wgPort)
p := &WGUserSpaceProxy{ p := &WGUDPProxy{
localWGListenPort: wgPort, localWGListenPort: wgPort,
} }
return p return p
@@ -42,7 +44,7 @@ func NewWGUserSpaceProxy(wgPort int) *WGUserSpaceProxy {
// the connection is complete, an error is returned. Once successfully // the connection is complete, an error is returned. Once successfully
// connected, any expiration of the context will not affect the // connected, any expiration of the context will not affect the
// connection. // connection.
func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn) error { func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error {
dialer := net.Dialer{} dialer := net.Dialer{}
localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort))
if err != nil { if err != nil {
@@ -57,7 +59,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(ctx context.Context, remoteConn net.Conn)
return err return err
} }
func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr { func (p *WGUDPProxy) EndpointAddr() *net.UDPAddr {
if p.localConn == nil { if p.localConn == nil {
return nil return nil
} }
@@ -66,7 +68,7 @@ func (p *WGUserSpaceProxy) EndpointAddr() *net.UDPAddr {
} }
// Work starts the proxy or resumes it if it was paused // Work starts the proxy or resumes it if it was paused
func (p *WGUserSpaceProxy) Work() { func (p *WGUDPProxy) Work() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
} }
@@ -83,7 +85,7 @@ func (p *WGUserSpaceProxy) Work() {
} }
// Pause pauses the proxy from receiving data from the remote peer // Pause pauses the proxy from receiving data from the remote peer
func (p *WGUserSpaceProxy) Pause() { func (p *WGUDPProxy) Pause() {
if p.remoteConn == nil { if p.remoteConn == nil {
return return
} }
@@ -94,14 +96,14 @@ func (p *WGUserSpaceProxy) Pause() {
} }
// CloseConn close the localConn // CloseConn close the localConn
func (p *WGUserSpaceProxy) CloseConn() error { func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil { if p.cancel == nil {
return fmt.Errorf("proxy not started") return fmt.Errorf("proxy not started")
} }
return p.close() return p.close()
} }
func (p *WGUserSpaceProxy) close() error { func (p *WGUDPProxy) close() error {
p.closeMu.Lock() p.closeMu.Lock()
defer p.closeMu.Unlock() defer p.closeMu.Unlock()
@@ -114,18 +116,18 @@ func (p *WGUserSpaceProxy) close() error {
p.cancel() p.cancel()
var result *multierror.Error var result *multierror.Error
if err := p.remoteConn.Close(); err != nil { if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) result = multierror.Append(result, fmt.Errorf("remote conn: %s", err))
} }
if err := p.localConn.Close(); err != nil { if err := p.localConn.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) result = multierror.Append(result, fmt.Errorf("local conn: %s", err))
} }
return errors.FormatErrorOrNil(result) return cerrors.FormatErrorOrNil(result)
} }
// proxyToRemote proxies from Wireguard to the RemoteKey // proxyToRemote proxies from Wireguard to the RemoteKey
func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) { func (p *WGUDPProxy) proxyToRemote(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to remote loop: %s", err) log.Warnf("error in proxy to remote loop: %s", err)
@@ -157,21 +159,19 @@ func (p *WGUserSpaceProxy) proxyToRemote(ctx context.Context) {
// proxyToLocal proxies from the Remote peer to local WireGuard // proxyToLocal proxies from the Remote peer to local WireGuard
// if the proxy is paused it will drain the remote conn and drop the packets // if the proxy is paused it will drain the remote conn and drop the packets
func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) { func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
defer func() { defer func() {
if err := p.close(); err != nil { if err := p.close(); err != nil {
log.Warnf("error in proxy to local loop: %s", err) if !errors.Is(err, io.EOF) {
log.Warnf("error in proxy to local loop: %s", err)
}
} }
}() }()
buf := make([]byte, 1500) buf := make([]byte, 1500)
for { for {
n, err := p.remoteConn.Read(buf) n, err := p.remoteConnRead(ctx, buf)
if err != nil { if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.RemoteAddr(), err)
return return
} }
@@ -193,3 +193,15 @@ func (p *WGUserSpaceProxy) proxyToLocal(ctx context.Context) {
} }
} }
} }
func (p *WGUDPProxy) remoteConnRead(ctx context.Context, buf []byte) (n int, err error) {
n, err = p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {
return
}
log.Errorf("failed to read from remote conn: %s, %s", p.remoteConn.LocalAddr(), err)
return
}
return
}

View File

@@ -3,6 +3,7 @@ package acl
import ( import (
"crypto/md5" "crypto/md5"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -10,14 +11,18 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap)
@@ -167,31 +172,40 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
} }
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
var newRouteRules = make(map[id.RuleID]struct{}) newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules { for _, rule := range rules {
id, err := d.applyRouteACL(rule) id, err := d.applyRouteACL(rule)
if err != nil { if err != nil {
return fmt.Errorf("apply route ACL: %w", err) if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
} else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
}
continue
} }
newRouteRules[id] = struct{}{} newRouteRules[id] = struct{}{}
} }
// Clean up old firewall rules
for id := range d.routeRules { for id := range d.routeRules {
if _, ok := newRouteRules[id]; !ok { if _, exists := newRouteRules[id]; !exists {
if err := d.firewall.DeleteRouteRule(id); err != nil { if err := d.firewall.DeleteRouteRule(id); err != nil {
log.Errorf("failed to delete route firewall rule: %v", err) merr = multierror.Append(merr, fmt.Errorf("delete route rule: %w", err))
continue
} }
delete(d.routeRules, id) // implicitly deleted from the map
} }
} }
d.routeRules = newRouteRules d.routeRules = newRouteRules
return nil return nberrors.FormatErrorOrNil(merr)
} }
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 { if len(rule.SourceRanges) == 0 {
return "", fmt.Errorf("source ranges is empty") return "", ErrSourceRangesEmpty
} }
var sources []netip.Prefix var sources []netip.Prefix

View File

@@ -1,7 +1,6 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
@@ -52,13 +51,13 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock) fw, err := firewall.NewFirewall(ifaceMock, nil)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset() _ = fw.Reset(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
@@ -345,13 +344,13 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes() }).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(context.Background(), ifaceMock) fw, err := firewall.NewFirewall(ifaceMock, nil)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset() _ = fw.Reset(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)

View File

@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err return cfg, err
} }
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path // WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error { func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(path, config) return util.WriteJson(context.Background(), path, config)
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
} }
if updated { if updated {
if err := util.WriteJson(input.ConfigPath, config); err != nil { if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err return nil, err
} }
} }

View File

@@ -62,10 +62,7 @@ func (c *ConnectClient) Run() error {
} }
// RunWithProbes runs the client's main logic with probes attached // RunWithProbes runs the client's main logic with probes attached
func (c *ConnectClient) RunWithProbes( func (c *ConnectClient) RunWithProbes(probes *ProbeHolder, runningChan chan error) error {
probes *ProbeHolder,
runningChan chan error,
) error {
return c.run(MobileDependency{}, probes, runningChan) return c.run(MobileDependency{}, probes, runningChan)
} }
@@ -104,11 +101,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil, nil) return c.run(mobileDependency, nil, nil)
} }
func (c *ConnectClient) run( func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHolder, runningChan chan error) error {
mobileDependency MobileDependency,
probes *ProbeHolder,
runningChan chan error,
) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack())) log.Panicf("Panic occurred: %v, stack trace: %s", r, string(debug.Stack()))
@@ -117,12 +110,6 @@ func (c *ConnectClient) run(
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH) log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
// Check if client was not shut down in a clean way and restore DNS config if required.
// Otherwise, we might not be able to connect to the management server to retrieve new config.
if err := dns.CheckUncleanShutdown(c.config.WgIface); err != nil {
log.Errorf("checking unclean shutdown error: %s", err)
}
backOff := &backoff.ExponentialBackOff{ backOff := &backoff.ExponentialBackOff{
InitialInterval: time.Second, InitialInterval: time.Second,
RandomizationFactor: 1, RandomizationFactor: 1,
@@ -170,7 +157,8 @@ func (c *ConnectClient) run(
engineCtx, cancel := context.WithCancel(c.ctx) engineCtx, cancel := context.WithCancel(c.ctx)
defer func() { defer func() {
c.statusRecorder.MarkManagementDisconnected(state.err) _, err := state.Status()
c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.CleanLocalPeerState() c.statusRecorder.CleanLocalPeerState()
cancel() cancel()
}() }()
@@ -220,7 +208,8 @@ func (c *ConnectClient) run(
c.statusRecorder.MarkSignalDisconnected(nil) c.statusRecorder.MarkSignalDisconnected(nil)
defer func() { defer func() {
c.statusRecorder.MarkSignalDisconnected(state.err) _, err := state.Status()
c.statusRecorder.MarkSignalDisconnected(err)
}() }()
// with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal // with the global Wiretrustee config in hand connect (just a connection, no stream yet) Signal
@@ -358,7 +347,11 @@ func (c *ConnectClient) Stop() error {
if c.engine == nil { if c.engine == nil {
return nil return nil
} }
return c.engine.Stop() if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
} }
func (c *ConnectClient) isContextCancelled() bool { func (c *ConnectClient) isContextCancelled() bool {

View File

@@ -1,6 +1,5 @@
package dns package dns
const ( const (
fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/db/netbird/manager"
) )

View File

@@ -3,6 +3,5 @@
package dns package dns
const ( const (
fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf"
fileUncleanShutdownManagerTypeLocation = "/var/lib/netbird/manager"
) )

View File

@@ -9,6 +9,8 @@ import (
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
var ( var (
@@ -20,7 +22,7 @@ var (
} }
) )
type repairConfFn func([]string, string, *resolvConf) error type repairConfFn func([]string, string, *resolvConf, *statemanager.Manager) error
type repair struct { type repair struct {
operationFile string operationFile string
@@ -40,7 +42,7 @@ func newRepair(operationFile string, updateFn repairConfFn) *repair {
} }
} }
func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string) { func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP string, stateManager *statemanager.Manager) {
if f.inotify != nil { if f.inotify != nil {
return return
} }
@@ -81,7 +83,7 @@ func (f *repair) watchFileChanges(nbSearchDomains []string, nbNameserverIP strin
log.Errorf("failed to rm inotify watch for resolv.conf: %s", err) log.Errorf("failed to rm inotify watch for resolv.conf: %s", err)
} }
err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf) err = f.updateFn(nbSearchDomains, nbNameserverIP, rConf, stateManager)
if err != nil { if err != nil {
log.Errorf("failed to repair resolv.conf: %v", err) log.Errorf("failed to repair resolv.conf: %v", err)
} }

View File

@@ -9,6 +9,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -104,14 +105,14 @@ nameserver 8.8.8.8`,
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf) error { updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(operationFile, updateFn) r := newRepair(operationFile, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755) err = os.WriteFile(operationFile, []byte(tt.touchedConfContent), 0755)
if err != nil { if err != nil {
@@ -151,14 +152,14 @@ searchdomain netbird.cloud something`
var changed bool var changed bool
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
updateFn := func([]string, string, *resolvConf) error { updateFn := func([]string, string, *resolvConf, *statemanager.Manager) error {
changed = true changed = true
cancel() cancel()
return nil return nil
} }
r := newRepair(tmpLink, updateFn) r := newRepair(tmpLink, updateFn)
r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1") r.watchFileChanges([]string{"netbird.cloud"}, "10.0.0.1", nil)
err = os.WriteFile(tmpLink, []byte(modifyContent), 0755) err = os.WriteFile(tmpLink, []byte(modifyContent), 0755)
if err != nil { if err != nil {

View File

@@ -11,6 +11,8 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -36,7 +38,7 @@ type fileConfigurator struct {
nbNameserverIP string nbNameserverIP string
} }
func newFileConfigurator() (hostManager, error) { func newFileConfigurator() (*fileConfigurator, error) {
fc := &fileConfigurator{} fc := &fileConfigurator{}
fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig) fc.repair = newRepair(defaultResolvConfPath, fc.updateConfig)
return fc, nil return fc, nil
@@ -46,7 +48,7 @@ func (f *fileConfigurator) supportCustomPort() bool {
return false return false
} }
func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error { func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
backupFileExist := f.isBackupFileExist() backupFileExist := f.isBackupFileExist()
if !config.RouteAll { if !config.RouteAll {
if backupFileExist { if backupFileExist {
@@ -76,15 +78,15 @@ func (f *fileConfigurator) applyDNSConfig(config HostDNSConfig) error {
f.repair.stopWatchFileChanges() f.repair.stopWatchFileChanges()
err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf) err = f.updateConfig(nbSearchDomains, f.nbNameserverIP, resolvConf, stateManager)
if err != nil { if err != nil {
return err return err
} }
f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP) f.repair.watchFileChanges(nbSearchDomains, f.nbNameserverIP, stateManager)
return nil return nil
} }
func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf) error { func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP string, cfg *resolvConf, stateManager *statemanager.Manager) error {
searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains) searchDomainList := mergeSearchDomains(nbSearchDomains, cfg.searchDomains)
nameServers := generateNsList(nbNameserverIP, cfg) nameServers := generateNsList(nbNameserverIP, cfg)
@@ -107,7 +109,7 @@ func (f *fileConfigurator) updateConfig(nbSearchDomains []string, nbNameserverIP
log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList) log.Infof("created a NetBird managed %s file with the DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
// create another backup for unclean shutdown detection right after overwriting the original resolv.conf // create another backup for unclean shutdown detection right after overwriting the original resolv.conf
if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, fileManager, nbNameserverIP); err != nil { if err := createUncleanShutdownIndicator(fileDefaultResolvConfBackupLocation, nbNameserverIP, stateManager); err != nil {
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err)
} }
@@ -145,10 +147,6 @@ func (f *fileConfigurator) restore() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileDefaultResolvConfBackupLocation, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return os.RemoveAll(fileDefaultResolvConfBackupLocation) return os.RemoveAll(fileDefaultResolvConfBackupLocation)
} }
@@ -176,7 +174,7 @@ func (f *fileConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Add
return restoreResolvConfFile() return restoreResolvConfFile()
} }
log.Info("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: not restoring") log.Infof("restoring unclean shutdown: first current nameserver differs from saved nameserver pre-netbird: %s (current) vs %s (stored): not restoring", currentDNSAddress, storedDNSAddress)
return nil return nil
} }
@@ -192,10 +190,6 @@ func restoreResolvConfFile() error {
return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err) return fmt.Errorf("restoring %s from %s: %w", defaultResolvConfPath, fileUncleanShutdownResolvConfLocation, err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf file: %s", err)
}
return nil return nil
} }

View File

@@ -5,14 +5,14 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
type hostManager interface { type hostManager interface {
applyDNSConfig(config HostDNSConfig) error applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNS() error restoreHostDNS() error
supportCustomPort() bool supportCustomPort() bool
restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error
} }
type SystemDNSSettings struct { type SystemDNSSettings struct {
@@ -35,15 +35,15 @@ type DomainConfig struct {
} }
type mockHostConfigurator struct { type mockHostConfigurator struct {
applyDNSConfigFunc func(config HostDNSConfig) error applyDNSConfigFunc func(config HostDNSConfig, stateManager *statemanager.Manager) error
restoreHostDNSFunc func() error restoreHostDNSFunc func() error
supportCustomPortFunc func() bool supportCustomPortFunc func() bool
restoreUncleanShutdownDNSFunc func(*netip.Addr) error restoreUncleanShutdownDNSFunc func(*netip.Addr) error
} }
func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig) error { func (m *mockHostConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if m.applyDNSConfigFunc != nil { if m.applyDNSConfigFunc != nil {
return m.applyDNSConfigFunc(config) return m.applyDNSConfigFunc(config, stateManager)
} }
return fmt.Errorf("method applyDNSSettings is not implemented") return fmt.Errorf("method applyDNSSettings is not implemented")
} }
@@ -62,16 +62,9 @@ func (m *mockHostConfigurator) supportCustomPort() bool {
return false return false
} }
func (m *mockHostConfigurator) restoreUncleanShutdownDNS(storedDNSAddress *netip.Addr) error {
if m.restoreUncleanShutdownDNSFunc != nil {
return m.restoreUncleanShutdownDNSFunc(storedDNSAddress)
}
return fmt.Errorf("method restoreUncleanShutdownDNS is not implemented")
}
func newNoopHostMocker() hostManager { func newNoopHostMocker() hostManager {
return &mockHostConfigurator{ return &mockHostConfigurator{
applyDNSConfigFunc: func(config HostDNSConfig) error { return nil }, applyDNSConfigFunc: func(config HostDNSConfig, stateManager *statemanager.Manager) error { return nil },
restoreHostDNSFunc: func() error { return nil }, restoreHostDNSFunc: func() error { return nil },
supportCustomPortFunc: func() bool { return true }, supportCustomPortFunc: func() bool { return true },
restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil }, restoreUncleanShutdownDNSFunc: func(*netip.Addr) error { return nil },

View File

@@ -1,15 +1,17 @@
package dns package dns
import "net/netip" import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
type androidHostManager struct { type androidHostManager struct {
} }
func newHostManager() (hostManager, error) { func newHostManager() (*androidHostManager, error) {
return &androidHostManager{}, nil return &androidHostManager{}, nil
} }
func (a androidHostManager) applyDNSConfig(config HostDNSConfig) error { func (a androidHostManager) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error {
return nil return nil
} }
@@ -20,7 +22,3 @@ func (a androidHostManager) restoreHostDNS() error {
func (a androidHostManager) supportCustomPort() bool { func (a androidHostManager) supportCustomPort() bool {
return false return false
} }
func (a androidHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@@ -8,12 +8,13 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os/exec" "os/exec"
"strconv" "strconv"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -37,7 +38,7 @@ type systemConfigurator struct {
systemDNSSettings SystemDNSSettings systemDNSSettings SystemDNSSettings
} }
func newHostManager() (hostManager, error) { func newHostManager() (*systemConfigurator, error) {
return &systemConfigurator{ return &systemConfigurator{
createdKeys: make(map[string]struct{}), createdKeys: make(map[string]struct{}),
}, nil }, nil
@@ -47,12 +48,11 @@ func (s *systemConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error var err error
// create a file for unclean shutdown detection if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
if err := createUncleanShutdownIndicator(); err != nil { log.Errorf("failed to update shutdown state: %s", err)
log.Errorf("failed to create unclean shutdown file: %s", err)
} }
var ( var (
@@ -123,10 +123,6 @@ func (s *systemConfigurator) restoreHostDNS() error {
} }
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil return nil
} }
@@ -320,7 +316,7 @@ func (s *systemConfigurator) getPrimaryService() (string, string, error) {
return primaryService, router, nil return primaryService, router, nil
} }
func (s *systemConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (s *systemConfigurator) restoreUncleanShutdownDNS() error {
if err := s.restoreHostDNS(); err != nil { if err := s.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via scutil: %w", err) return fmt.Errorf("restoring dns via scutil: %w", err)
} }

View File

@@ -3,9 +3,10 @@ package dns
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/netip"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type iosHostManager struct { type iosHostManager struct {
@@ -13,13 +14,13 @@ type iosHostManager struct {
config HostDNSConfig config HostDNSConfig
} }
func newHostManager(dnsManager IosDnsManager) (hostManager, error) { func newHostManager(dnsManager IosDnsManager) (*iosHostManager, error) {
return &iosHostManager{ return &iosHostManager{
dnsManager: dnsManager, dnsManager: dnsManager,
}, nil }, nil
} }
func (a iosHostManager) applyDNSConfig(config HostDNSConfig) error { func (a iosHostManager) applyDNSConfig(config HostDNSConfig, _ *statemanager.Manager) error {
jsonData, err := json.Marshal(config) jsonData, err := json.Marshal(config)
if err != nil { if err != nil {
return fmt.Errorf("marshal: %w", err) return fmt.Errorf("marshal: %w", err)
@@ -37,7 +38,3 @@ func (a iosHostManager) restoreHostDNS() error {
func (a iosHostManager) supportCustomPort() bool { func (a iosHostManager) supportCustomPort() bool {
return false return false
} }
func (a iosHostManager) restoreUncleanShutdownDNS(*netip.Addr) error {
return nil
}

View File

@@ -4,9 +4,9 @@ package dns
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"net/netip"
"os" "os"
"strings" "strings"
@@ -21,27 +21,8 @@ const (
resolvConfManager resolvConfManager
) )
var ErrUnknownOsManagerType = errors.New("unknown os manager type")
type osManagerType int type osManagerType int
func newOsManagerType(osManager string) (osManagerType, error) {
switch osManager {
case "netbird":
return fileManager, nil
case "file":
return netbirdManager, nil
case "networkManager":
return networkManager, nil
case "systemd":
return systemdManager, nil
case "resolvconf":
return resolvConfManager, nil
default:
return 0, ErrUnknownOsManagerType
}
}
func (t osManagerType) String() string { func (t osManagerType) String() string {
switch t { switch t {
case netbirdManager: case netbirdManager:
@@ -59,6 +40,11 @@ func (t osManagerType) String() string {
} }
} }
type restoreHostManager interface {
hostManager
restoreUncleanShutdownDNS(*netip.Addr) error
}
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
@@ -69,7 +55,7 @@ func newHostManager(wgInterface string) (hostManager, error) {
return newHostManagerFromType(wgInterface, osManager) return newHostManagerFromType(wgInterface, osManager)
} }
func newHostManagerFromType(wgInterface string, osManager osManagerType) (hostManager, error) { func newHostManagerFromType(wgInterface string, osManager osManagerType) (restoreHostManager, error) {
switch osManager { switch osManager {
case networkManager: case networkManager:
return newNetworkManagerDbusConfigurator(wgInterface) return newNetworkManagerDbusConfigurator(wgInterface)

View File

@@ -3,11 +3,12 @@ package dns
import ( import (
"fmt" "fmt"
"io" "io"
"net/netip"
"strings" "strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -31,7 +32,7 @@ type registryConfigurator struct {
routingAll bool routingAll bool
} }
func newHostManager(wgInterface WGIface) (hostManager, error) { func newHostManager(wgInterface WGIface) (*registryConfigurator, error) {
guid, err := wgInterface.GetInterfaceGUIDString() guid, err := wgInterface.GetInterfaceGUIDString()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -39,7 +40,7 @@ func newHostManager(wgInterface WGIface) (hostManager, error) {
return newHostManagerWithGuid(guid) return newHostManagerWithGuid(guid)
} }
func newHostManagerWithGuid(guid string) (hostManager, error) { func newHostManagerWithGuid(guid string) (*registryConfigurator, error) {
return &registryConfigurator{ return &registryConfigurator{
guid: guid, guid: guid,
}, nil }, nil
@@ -49,7 +50,7 @@ func (r *registryConfigurator) supportCustomPort() bool {
return false return false
} }
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error { func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error var err error
if config.RouteAll { if config.RouteAll {
err = r.addDNSSetupForAll(config.ServerIP) err = r.addDNSSetupForAll(config.ServerIP)
@@ -65,9 +66,8 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
} }
// create a file for unclean shutdown detection if err := stateManager.UpdateState(&ShutdownState{Guid: r.guid}); err != nil {
if err := createUncleanShutdownIndicator(r.guid); err != nil { log.Errorf("failed to update shutdown state: %s", err)
log.Errorf("failed to create unclean shutdown file: %s", err)
} }
var ( var (
@@ -160,10 +160,6 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown file: %s", err)
}
return nil return nil
} }
@@ -221,7 +217,7 @@ func (r *registryConfigurator) getInterfaceRegistryKey() (registry.Key, error) {
return regKey, nil return regKey, nil
} }
func (r *registryConfigurator) restoreUncleanShutdownDNS(*netip.Addr) error { func (r *registryConfigurator) restoreUncleanShutdownDNS() error {
if err := r.restoreHostDNS(); err != nil { if err := r.restoreHostDNS(); err != nil {
return fmt.Errorf("restoring dns via registry: %w", err) return fmt.Errorf("restoring dns via registry: %w", err)
} }

View File

@@ -16,6 +16,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbversion "github.com/netbirdio/netbird/version" nbversion "github.com/netbirdio/netbird/version"
) )
@@ -53,6 +54,7 @@ var supportedNetworkManagerVersionConstraints = []string{
type networkManagerDbusConfigurator struct { type networkManagerDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
ifaceName string
} }
// the types below are based on dbus specification, each field is mapped to a dbus type // the types below are based on dbus specification, each field is mapped to a dbus type
@@ -77,7 +79,7 @@ func (s networkManagerConnSettings) cleanDeprecatedSettings() {
} }
} }
func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error) { func newNetworkManagerDbusConfigurator(wgInterface string) (*networkManagerDbusConfigurator, error) {
obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode) obj, closeConn, err := getDbusObject(networkManagerDest, networkManagerDbusObjectNode)
if err != nil { if err != nil {
return nil, fmt.Errorf("get nm dbus: %w", err) return nil, fmt.Errorf("get nm dbus: %w", err)
@@ -93,6 +95,7 @@ func newNetworkManagerDbusConfigurator(wgInterface string) (hostManager, error)
return &networkManagerDbusConfigurator{ return &networkManagerDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil }, nil
} }
@@ -100,7 +103,7 @@ func (n *networkManagerDbusConfigurator) supportCustomPort() bool {
return false return false
} }
func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
connSettings, configVersion, err := n.getAppliedConnectionSettings() connSettings, configVersion, err := n.getAppliedConnectionSettings()
if err != nil { if err != nil {
return fmt.Errorf("retrieving the applied connection settings, error: %w", err) return fmt.Errorf("retrieving the applied connection settings, error: %w", err)
@@ -151,10 +154,12 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig) er
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSPriorityKey] = dbus.MakeVariant(priority)
connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList) connSettings[networkManagerDbusIPv4Key][networkManagerDbusDNSSearchKey] = dbus.MakeVariant(newDomainList)
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. state := &ShutdownState{
// The file content itself is not important for network-manager restoration ManagerType: networkManager,
if err := createUncleanShutdownIndicator(defaultResolvConfPath, networkManager, dnsIP.String()); err != nil { WgIface: n.ifaceName,
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) }
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@@ -171,10 +176,6 @@ func (n *networkManagerDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("delete connection settings: %w", err) return fmt.Errorf("delete connection settings: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return nil return nil
} }

View File

@@ -9,6 +9,8 @@ import (
"os/exec" "os/exec"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const resolvconfCommand = "resolvconf" const resolvconfCommand = "resolvconf"
@@ -22,7 +24,7 @@ type resolvconf struct {
} }
// supported "openresolv" only // supported "openresolv" only
func newResolvConfConfigurator(wgInterface string) (hostManager, error) { func newResolvConfConfigurator(wgInterface string) (*resolvconf, error) {
resolvConfEntries, err := parseDefaultResolvConf() resolvConfEntries, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err) log.Errorf("could not read original search domains from %s: %s", defaultResolvConfPath, err)
@@ -40,7 +42,7 @@ func (r *resolvconf) supportCustomPort() bool {
return false return false
} }
func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error { func (r *resolvconf) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
var err error var err error
if !config.RouteAll { if !config.RouteAll {
err = r.restoreHostDNS() err = r.restoreHostDNS()
@@ -60,9 +62,12 @@ func (r *resolvconf) applyDNSConfig(config HostDNSConfig) error {
append([]string{config.ServerIP}, r.originalNameServers...), append([]string{config.ServerIP}, r.originalNameServers...),
options) options)
// create a backup for unclean shutdown detection before the resolv.conf is changed state := &ShutdownState{
if err := createUncleanShutdownIndicator(defaultResolvConfPath, resolvConfManager, config.ServerIP); err != nil { ManagerType: resolvConfManager,
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) WgIface: r.ifaceName,
}
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
err = r.applyConfig(buf) err = r.applyConfig(buf)
@@ -79,11 +84,7 @@ func (r *resolvconf) restoreHostDNS() error {
cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName) cmd := exec.Command(resolvconfCommand, "-f", "-d", r.ifaceName)
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("removing resolvconf configuration for %s interface, error: %w", r.ifaceName, err) return fmt.Errorf("removing resolvconf configuration for %s interface: %w", r.ifaceName, err)
}
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
} }
return nil return nil
@@ -95,7 +96,7 @@ func (r *resolvconf) applyConfig(content bytes.Buffer) error {
cmd.Stdin = &content cmd.Stdin = &content
_, err := cmd.Output() _, err := cmd.Output()
if err != nil { if err != nil {
return fmt.Errorf("applying resolvconf configuration for %s interface, error: %w", r.ifaceName, err) return fmt.Errorf("applying resolvconf configuration for %s interface: %w", r.ifaceName, err)
} }
return nil return nil
} }

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@@ -63,6 +64,7 @@ type DefaultServer struct {
iosDnsManager IosDnsManager iosDnsManager IosDnsManager
statusRecorder *peer.Status statusRecorder *peer.Status
stateManager *statemanager.Manager
} }
type handlerWithStop interface { type handlerWithStop interface {
@@ -77,12 +79,7 @@ type muxUpdate struct {
} }
// NewDefaultServer returns a new dns server // NewDefaultServer returns a new dns server
func NewDefaultServer( func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) {
ctx context.Context,
wgInterface WGIface,
customAddress string,
statusRecorder *peer.Status,
) (*DefaultServer, error) {
var addrPort *netip.AddrPort var addrPort *netip.AddrPort
if customAddress != "" { if customAddress != "" {
parsedAddrPort, err := netip.ParseAddrPort(customAddress) parsedAddrPort, err := netip.ParseAddrPort(customAddress)
@@ -99,7 +96,7 @@ func NewDefaultServer(
dnsService = newServiceViaListener(wgInterface, addrPort) dnsService = newServiceViaListener(wgInterface, addrPort)
} }
return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder), nil return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil
} }
// NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems
@@ -112,7 +109,7 @@ func NewDefaultServerPermanentUpstream(
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
log.Debugf("host dns address list is: %v", hostsDnsList) log.Debugf("host dns address list is: %v", hostsDnsList)
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds.hostsDNSHolder.set(hostsDnsList) ds.hostsDNSHolder.set(hostsDnsList)
ds.permanent = true ds.permanent = true
ds.addHostRootZone() ds.addHostRootZone()
@@ -130,12 +127,12 @@ func NewDefaultServerIos(
iosDnsManager IosDnsManager, iosDnsManager IosDnsManager,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *DefaultServer { ) *DefaultServer {
ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder) ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil)
ds.iosDnsManager = iosDnsManager ds.iosDnsManager = iosDnsManager
return ds return ds
} }
func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status) *DefaultServer { func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer {
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
@@ -147,6 +144,7 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi
}, },
wgInterface: wgInterface, wgInterface: wgInterface,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
stateManager: stateManager,
hostsDNSHolder: newHostsDNSHolder(), hostsDNSHolder: newHostsDNSHolder(),
} }
@@ -169,6 +167,7 @@ func (s *DefaultServer) Initialize() (err error) {
} }
} }
s.stateManager.RegisterState(&ShutdownState{})
s.hostManager, err = s.initialize() s.hostManager, err = s.initialize()
if err != nil { if err != nil {
return fmt.Errorf("initialize: %w", err) return fmt.Errorf("initialize: %w", err)
@@ -191,9 +190,10 @@ func (s *DefaultServer) Stop() {
s.ctxCancel() s.ctxCancel()
if s.hostManager != nil { if s.hostManager != nil {
err := s.hostManager.restoreHostDNS() if err := s.hostManager.restoreHostDNS(); err != nil {
if err != nil { log.Error("failed to restore host DNS settings: ", err)
log.Error(err) } else if err := s.stateManager.DeleteState(&ShutdownState{}); err != nil {
log.Errorf("failed to delete shutdown dns state: %v", err)
} }
} }
@@ -318,10 +318,17 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
hostUpdate.RouteAll = false hostUpdate.RouteAll = false
} }
if err = s.hostManager.applyDNSConfig(hostUpdate); err != nil { if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
log.Error(err) log.Error(err)
} }
go func() {
// persist dns state right away
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
if s.searchDomainNotifier != nil { if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
} }
@@ -521,10 +528,16 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
} }
go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil {
l.Errorf("Failed to persist dns state: %v", err)
}
}()
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone() s.addHostRootZone()
} }
@@ -551,7 +564,7 @@ func (s *DefaultServer) upstreamCallbacks(
s.currentConfig.RouteAll = true s.currentConfig.RouteAll = true
s.service.RegisterMux(nbdns.RootZone, handler) s.service.RegisterMux(nbdns.RootZone, handler)
} }
if err := s.hostManager.applyDNSConfig(s.currentConfig); err != nil { if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
} }

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
@@ -267,7 +268,17 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: fmt.Sprintf("100.66.100.%d/32", n+1),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -281,7 +292,7 @@ func TestUpdateDNSServer(t *testing.T) {
t.Log(err) t.Log(err)
} }
}() }()
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -345,7 +356,15 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil) opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: "100.66.100.1/32",
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil { if err != nil {
t.Errorf("build interface wireguard: %v", err) t.Errorf("build interface wireguard: %v", err)
return return
@@ -382,7 +401,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
return return
} }
dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}) dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil)
if err != nil { if err != nil {
t.Errorf("create DNS server: %v", err) t.Errorf("create DNS server: %v", err)
return return
@@ -477,7 +496,7 @@ func TestDNSServerStartStop(t *testing.T) {
for _, testCase := range testCases { for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) { t.Run(testCase.name, func(t *testing.T) {
dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}) dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil)
if err != nil { if err != nil {
t.Fatalf("%v", err) t.Fatalf("%v", err)
} }
@@ -536,6 +555,7 @@ func TestDNSServerStartStop(t *testing.T) {
func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
hostManager := &mockHostConfigurator{} hostManager := &mockHostConfigurator{}
server := DefaultServer{ server := DefaultServer{
ctx: context.Background(),
service: NewServiceViaMemory(&mocWGIface{}), service: NewServiceViaMemory(&mocWGIface{}),
localResolver: &localResolver{ localResolver: &localResolver{
registeredMap: make(registrationMap), registeredMap: make(registrationMap),
@@ -552,7 +572,7 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) {
} }
var domainsUpdate string var domainsUpdate string
hostManager.applyDNSConfigFunc = func(config HostDNSConfig) error { hostManager.applyDNSConfigFunc = func(config HostDNSConfig, statemanager *statemanager.Manager) error {
domains := []string{} domains := []string{}
for _, item := range config.Domains { for _, item := range config.Domains {
if item.Disabled { if item.Disabled {
@@ -762,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53, Port: 53,
}, },
}, },
Domains: []string{"customdomain.com"}, Domains: []string{"google.com"},
Primary: false, Primary: false,
}, },
}, },
@@ -784,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData { if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err) t.Fatalf("invalid zone record: %v", err)
} }
_, err = resolver.LookupHost(context.Background(), "customdomain.com") _, err = resolver.LookupHost(context.Background(), "google.com")
if err != nil { if err != nil {
t.Errorf("failed to resolve: %s", err) t.Errorf("failed to resolve: %s", err)
} }
@@ -803,7 +823,17 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: "100.66.100.2/24",
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil { if err != nil {
t.Fatalf("build interface wireguard: %v", err) t.Fatalf("build interface wireguard: %v", err)
return nil, err return nil, err

View File

@@ -1,5 +1,5 @@
package dns package dns
func (s *DefaultServer) initialize() (manager hostManager, err error) { func (s *DefaultServer) initialize() (hostManager, error) {
return newHostManager(s.wgInterface) return newHostManager(s.wgInterface)
} }

View File

@@ -7,7 +7,7 @@ import (
var errNotImplemented = errors.New("not implemented") var errNotImplemented = errors.New("not implemented")
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { func newSystemdDbusConfigurator(string) (restoreHostManager, error) {
return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented) return nil, fmt.Errorf("systemd dns management: %w on freebsd", errNotImplemented)
} }

View File

@@ -15,6 +15,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@@ -38,6 +39,7 @@ const (
type systemdDbusConfigurator struct { type systemdDbusConfigurator struct {
dbusLinkObject dbus.ObjectPath dbusLinkObject dbus.ObjectPath
routingAll bool routingAll bool
ifaceName string
} }
// the types below are based on dbus specification, each field is mapped to a dbus type // the types below are based on dbus specification, each field is mapped to a dbus type
@@ -55,7 +57,7 @@ type systemdDbusLinkDomainsInput struct {
MatchOnly bool MatchOnly bool
} }
func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) { func newSystemdDbusConfigurator(wgInterface string) (*systemdDbusConfigurator, error) {
iface, err := net.InterfaceByName(wgInterface) iface, err := net.InterfaceByName(wgInterface)
if err != nil { if err != nil {
return nil, fmt.Errorf("get interface: %w", err) return nil, fmt.Errorf("get interface: %w", err)
@@ -77,6 +79,7 @@ func newSystemdDbusConfigurator(wgInterface string) (hostManager, error) {
return &systemdDbusConfigurator{ return &systemdDbusConfigurator{
dbusLinkObject: dbus.ObjectPath(s), dbusLinkObject: dbus.ObjectPath(s),
ifaceName: wgInterface,
}, nil }, nil
} }
@@ -84,7 +87,7 @@ func (s *systemdDbusConfigurator) supportCustomPort() bool {
return true return true
} }
func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error { func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
parsedIP, err := netip.ParseAddr(config.ServerIP) parsedIP, err := netip.ParseAddr(config.ServerIP)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse ip address, error: %w", err) return fmt.Errorf("unable to parse ip address, error: %w", err)
@@ -135,10 +138,12 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig) error {
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort) log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
} }
// create a backup for unclean shutdown detection before adding domains, as these might end up in the resolv.conf file. state := &ShutdownState{
// The file content itself is not important for systemd restoration ManagerType: systemdManager,
if err := createUncleanShutdownIndicator(defaultResolvConfPath, systemdManager, parsedIP.String()); err != nil { WgIface: s.ifaceName,
log.Errorf("failed to create unclean shutdown resolv.conf backup: %s", err) }
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update shutdown state: %s", err)
} }
log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains) log.Infof("adding %d search domains and %d match domains. Search list: %s , Match list: %s", len(searchDomains), len(matchDomains), searchDomains, matchDomains)
@@ -174,10 +179,6 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
return fmt.Errorf("unable to revert link configuration, got error: %w", err) return fmt.Errorf("unable to revert link configuration, got error: %w", err)
} }
if err := removeUncleanShutdownIndicator(); err != nil {
log.Errorf("failed to remove unclean shutdown resolv.conf backup: %s", err)
}
return s.flushCaches() return s.flushCaches()
} }

View File

@@ -1,5 +0,0 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@@ -3,57 +3,25 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
) )
const fileUncleanShutdownFileLocation = "/var/lib/netbird/unclean_shutdown_dns" type ShutdownState struct {
}
func CheckUncleanShutdown(string) error { func (s *ShutdownState) Name() string {
if _, err := os.Stat(fileUncleanShutdownFileLocation); err != nil { return "dns_state"
if errors.Is(err, fs.ErrNotExist) { }
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", fileUncleanShutdownFileLocation)
func (s *ShutdownState) Cleanup() error {
manager, err := newHostManager() manager, err := newHostManager()
if err != nil { if err != nil {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(nil); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator() error {
dir := filepath.Dir(fileUncleanShutdownFileLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(fileUncleanShutdownFileLocation, nil, 0644); err != nil { //nolint:gosec
return fmt.Errorf("create %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownFileLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownFileLocation, err)
}
return nil
}

View File

@@ -1,5 +0,0 @@
package dns
func CheckUncleanShutdown(string) error {
return nil
}

View File

@@ -0,0 +1,14 @@
//go:build ios || android
package dns
type ShutdownState struct {
}
func (s *ShutdownState) Name() string {
return "dns_state"
}
func (s *ShutdownState) Cleanup() error {
return nil
}

View File

@@ -3,66 +3,44 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"net/netip" "net/netip"
"os" "os"
"path/filepath" "path/filepath"
"strings"
log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
func CheckUncleanShutdown(wgIface string) error { type ShutdownState struct {
if _, err := os.Stat(fileUncleanShutdownResolvConfLocation); err != nil { ManagerType osManagerType
if errors.Is(err, fs.ErrNotExist) { DNSAddress netip.Addr
// no file -> clean shutdown WgIface string
return nil }
} else {
return fmt.Errorf("state: %w", err)
}
}
log.Warnf("detected unclean shutdown, file %s exists", fileUncleanShutdownResolvConfLocation) func (s *ShutdownState) Name() string {
return "dns_state"
}
managerData, err := os.ReadFile(fileUncleanShutdownManagerTypeLocation) func (s *ShutdownState) Cleanup() error {
if err != nil { manager, err := newHostManagerFromType(s.WgIface, s.ManagerType)
return fmt.Errorf("read %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
managerFields := strings.Split(string(managerData), ",")
if len(managerFields) < 2 {
return errors.New("split manager data: insufficient number of fields")
}
osManagerTypeStr, dnsAddressStr := managerFields[0], managerFields[1]
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s failed: %w", dnsAddressStr, err)
}
log.Warnf("restoring unclean shutdown dns settings via previously detected manager: %s", osManagerTypeStr)
// determine os manager type, so we can invoke the respective restore action
osManagerType, err := newOsManagerType(osManagerTypeStr)
if err != nil {
return fmt.Errorf("detect previous host manager: %w", err)
}
manager, err := newHostManagerFromType(wgIface, osManagerType)
if err != nil { if err != nil {
return fmt.Errorf("create previous host manager: %w", err) return fmt.Errorf("create previous host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(&dnsAddress); err != nil { if err := manager.restoreUncleanShutdownDNS(&s.DNSAddress); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType, dnsAddress string) error { // TODO: move file contents to state manager
func createUncleanShutdownIndicator(sourcePath string, dnsAddressStr string, stateManager *statemanager.Manager) error {
dnsAddress, err := netip.ParseAddr(dnsAddressStr)
if err != nil {
return fmt.Errorf("parse dns address %s: %w", dnsAddressStr, err)
}
dir := filepath.Dir(fileUncleanShutdownResolvConfLocation) dir := filepath.Dir(fileUncleanShutdownResolvConfLocation)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil { if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err) return fmt.Errorf("create dir %s: %w", dir, err)
@@ -72,20 +50,13 @@ func createUncleanShutdownIndicator(sourcePath string, managerType osManagerType
return fmt.Errorf("create %s: %w", sourcePath, err) return fmt.Errorf("create %s: %w", sourcePath, err)
} }
managerData := fmt.Sprintf("%s,%s", managerType, dnsAddress) state := &ShutdownState{
ManagerType: fileManager,
if err := os.WriteFile(fileUncleanShutdownManagerTypeLocation, []byte(managerData), 0644); err != nil { //nolint:gosec DNSAddress: dnsAddress,
return fmt.Errorf("create %s: %w", fileUncleanShutdownManagerTypeLocation, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
if err := os.Remove(fileUncleanShutdownResolvConfLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownResolvConfLocation, err)
}
if err := os.Remove(fileUncleanShutdownManagerTypeLocation); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", fileUncleanShutdownManagerTypeLocation, err)
} }
if err := stateManager.UpdateState(state); err != nil {
return fmt.Errorf("update state: %w", err)
}
return nil return nil
} }

View File

@@ -1,75 +1,26 @@
package dns package dns
import ( import (
"errors"
"fmt" "fmt"
"io/fs"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
) )
const ( type ShutdownState struct {
netbirdProgramDataLocation = "Netbird" Guid string
fileUncleanShutdownFile = "unclean_shutdown_dns.txt" }
)
func CheckUncleanShutdown(string) error { func (s *ShutdownState) Name() string {
file := getUncleanShutdownFile() return "dns_state"
}
if _, err := os.Stat(file); err != nil { func (s *ShutdownState) Cleanup() error {
if errors.Is(err, fs.ErrNotExist) { manager, err := newHostManagerWithGuid(s.Guid)
// no file -> clean shutdown
return nil
} else {
return fmt.Errorf("state: %w", err)
}
}
logrus.Warnf("detected unclean shutdown, file %s exists. Restoring unclean shutdown dns settings.", file)
guid, err := os.ReadFile(file)
if err != nil {
return fmt.Errorf("read %s: %w", file, err)
}
manager, err := newHostManagerWithGuid(string(guid))
if err != nil { if err != nil {
return fmt.Errorf("create host manager: %w", err) return fmt.Errorf("create host manager: %w", err)
} }
if err := manager.restoreUncleanShutdownDNS(nil); err != nil { if err := manager.restoreUncleanShutdownDNS(); err != nil {
return fmt.Errorf("restore unclean shutdown backup: %w", err) return fmt.Errorf("restore unclean shutdown dns: %w", err)
} }
return nil return nil
} }
func createUncleanShutdownIndicator(guid string) error {
file := getUncleanShutdownFile()
dir := filepath.Dir(file)
if err := os.MkdirAll(dir, os.FileMode(0755)); err != nil {
return fmt.Errorf("create dir %s: %w", dir, err)
}
if err := os.WriteFile(file, []byte(guid), 0600); err != nil {
return fmt.Errorf("create %s: %w", file, err)
}
return nil
}
func removeUncleanShutdownIndicator() error {
file := getUncleanShutdownFile()
if err := os.Remove(file); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("remove %s: %w", file, err)
}
return nil
}
func getUncleanShutdownFile() string {
return filepath.Join(os.Getenv("PROGRAMDATA"), netbirdProgramDataLocation, fileUncleanShutdownFile)
}

View File

@@ -11,6 +11,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"slices" "slices"
"sort"
"strings" "strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -23,19 +24,21 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/statemanager"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
@@ -141,8 +144,7 @@ type Engine struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
wgInterface iface.IWGIface wgInterface iface.IWGIface
wgProxyFactory *wgproxy.Factory
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
@@ -168,6 +170,8 @@ type Engine struct {
checks []*mgmProto.Checks checks []*mgmProto.Checks
relayManager *relayClient.Manager relayManager *relayClient.Manager
stateManager *statemanager.Manager
srWatcher *guard.SRWatcher
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -215,7 +219,7 @@ func NewEngineWithProbes(
probes *ProbeHolder, probes *ProbeHolder,
checks []*mgmProto.Checks, checks []*mgmProto.Checks,
) *Engine { ) *Engine {
return &Engine{ engine := &Engine{
clientCtx: clientCtx, clientCtx: clientCtx,
clientCancel: clientCancel, clientCancel: clientCancel,
signal: signalClient, signal: signalClient,
@@ -234,6 +238,11 @@ func NewEngineWithProbes(
probes: probes, probes: probes,
checks: checks, checks: checks,
} }
if path := statemanager.GetDefaultStatePath(); path != "" {
engine.stateManager = statemanager.New(path)
}
return engine
} }
func (e *Engine) Stop() error { func (e *Engine) Stop() error {
@@ -255,7 +264,11 @@ func (e *Engine) Stop() error {
e.stopDNSServer() e.stopDNSServer()
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop() e.routeManager.Stop(e.stateManager)
}
if e.srWatcher != nil {
e.srWatcher.Close()
} }
err := e.removeAllPeers() err := e.removeAllPeers()
@@ -277,6 +290,17 @@ func (e *Engine) Stop() error {
e.close() e.close()
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
return nil return nil
} }
@@ -299,9 +323,6 @@ func (e *Engine) Start() error {
} }
e.wgInterface = wgIface e.wgInterface = wgIface
userspace := e.wgInterface.IsUserspaceBind()
e.wgProxyFactory = wgproxy.NewFactory(userspace, e.config.WgPort)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive { if e.config.RosenpassPermissive {
@@ -319,6 +340,8 @@ func (e *Engine) Start() error {
} }
} }
e.stateManager.Start()
initialRoutes, dnsServer, err := e.newDnsServer() initialRoutes, dnsServer, err := e.newDnsServer()
if err != nil { if err != nil {
e.close() e.close()
@@ -327,7 +350,7 @@ func (e *Engine) Start() error {
e.dnsServer = dnsServer e.dnsServer = dnsServer
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes)
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager)
if err != nil { if err != nil {
log.Errorf("Failed to initialize route manager: %s", err) log.Errorf("Failed to initialize route manager: %s", err)
} else { } else {
@@ -344,7 +367,7 @@ func (e *Engine) Start() error {
return fmt.Errorf("create wg interface: %w", err) return fmt.Errorf("create wg interface: %w", err)
} }
e.firewall, err = firewall.NewFirewall(e.ctx, e.wgInterface) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
} }
@@ -374,6 +397,18 @@ func (e *Engine) Start() error {
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
iceCfg := icemaker.Config{
StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery,
UDPMux: e.udpMux.UDPMuxDefault,
UDPMuxSrflx: e.udpMux,
NATExternalIPs: e.parseNATExternalIPMappings(),
}
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start()
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveProbeEvents() e.receiveProbeEvents()
@@ -606,6 +641,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
} }
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
}
if e.wgInterface.Address().String() != conf.Address { if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String() oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
@@ -956,7 +995,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
LocalWgPort: e.config.WgPort, LocalWgPort: e.config.WgPort,
RosenpassPubKey: e.getRosenpassPubKey(), RosenpassPubKey: e.getRosenpassPubKey(),
RosenpassAddr: e.getRosenpassAddr(), RosenpassAddr: e.getRosenpassAddr(),
ICEConfig: peer.ICEConfig{ ICEConfig: icemaker.Config{
StunTurn: &e.stunTurn, StunTurn: &e.stunTurn,
InterfaceBlackList: e.config.IFaceBlackList, InterfaceBlackList: e.config.IFaceBlackList,
DisableIPv6Discovery: e.config.DisableIPv6Discovery, DisableIPv6Discovery: e.config.DisableIPv6Discovery,
@@ -966,7 +1005,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e
}, },
} }
peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.wgProxyFactory, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager) peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1117,12 +1156,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
} }
func (e *Engine) close() { func (e *Engine) close() {
if e.wgProxyFactory != nil {
if err := e.wgProxyFactory.Free(); err != nil {
log.Errorf("failed closing ebpf proxy: %s", err)
}
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil { if err := e.wgInterface.Close(); err != nil {
@@ -1139,7 +1172,7 @@ func (e *Engine) close() {
} }
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Reset() err := e.firewall.Reset(e.stateManager)
if err != nil { if err != nil {
log.Warnf("failed to reset firewall: %s", err) log.Warnf("failed to reset firewall: %s", err)
} }
@@ -1167,21 +1200,29 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
log.Errorf("failed to create pion's stdnet: %s", err) log.Errorf("failed to create pion's stdnet: %s", err)
} }
var mArgs *device.MobileIFaceArguments opts := iface.WGIFaceOpts{
IFaceName: e.config.WgIfaceName,
Address: e.config.WgAddr,
WGPort: e.config.WgPort,
WGPrivKey: e.config.WgPrivateKey.String(),
MTU: iface.DefaultMTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
}
switch runtime.GOOS { switch runtime.GOOS {
case "android": case "android":
mArgs = &device.MobileIFaceArguments{ opts.MobileArgs = &device.MobileIFaceArguments{
TunAdapter: e.mobileDep.TunAdapter, TunAdapter: e.mobileDep.TunAdapter,
TunFd: int(e.mobileDep.FileDescriptor), TunFd: int(e.mobileDep.FileDescriptor),
} }
case "ios": case "ios":
mArgs = &device.MobileIFaceArguments{ opts.MobileArgs = &device.MobileIFaceArguments{
TunFd: int(e.mobileDep.FileDescriptor), TunFd: int(e.mobileDep.FileDescriptor),
} }
default:
} }
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes) return iface.NewWGIFace(opts)
} }
func (e *Engine) wgInterfaceCreate() (err error) { func (e *Engine) wgInterfaceCreate() (err error) {
@@ -1222,10 +1263,11 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) {
dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder)
return nil, dnsServer, nil return nil, dnsServer, nil
default: default:
dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder) dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return nil, dnsServer, nil return nil, dnsServer, nil
} }
} }
@@ -1443,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks {
sort.Slice(check.Files, func(i, j int) bool {
return check.Files[i] < check.Files[j]
})
}
for _, oCheck := range oChecks {
sort.Slice(oCheck.Files, func(i, j int) bool {
return oCheck.Files[i] < oCheck.Files[j]
})
}
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files) return slices.Equal(checks.Files, oChecks.Files)
}) })

View File

@@ -29,6 +29,8 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@@ -258,6 +260,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
} }
engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn}) engine.udpMux = bind.NewUniversalUDPMuxDefault(bind.UniversalUDPMuxParams{UDPConn: conn})
engine.ctx = ctx engine.ctx = ctx
engine.srWatcher = guard.NewSRWatcher(nil, nil, nil, icemaker.Config{})
type testCase struct { type testCase struct {
name string name string
@@ -602,7 +605,16 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
opts := iface.WGIFaceOpts{
IFaceName: wgIfaceName,
Address: wgAddr,
WGPort: engine.config.WgPort,
WGPrivKey: key.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
engine.wgInterface, err = iface.NewWGIFace(opts)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@@ -774,7 +786,15 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil) opts := iface.WGIFaceOpts{
IFaceName: wgIfaceName,
Address: wgAddr,
WGPort: 33100,
WGPrivKey: key.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
engine.wgInterface, err = iface.NewWGIFace(opts)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{
@@ -986,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
} }
} }
func Test_CheckFilesEqual(t *testing.T) {
testCases := []struct {
name string
inputChecks1 []*mgmtProto.Checks
inputChecks2 []*mgmtProto.Checks
expectedBool bool
}{
{
name: "Equal Files In Equal Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
expectedBool: true,
},
{
name: "Equal Files In Reverse Order Should Return True",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile2",
"testfile1",
},
},
},
expectedBool: true,
},
{
name: "Unequal Files Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile3",
},
},
},
expectedBool: false,
},
{
name: "Compared With Empty Should Return False",
inputChecks1: []*mgmtProto.Checks{
{
Files: []string{
"testfile1",
"testfile2",
},
},
},
inputChecks2: []*mgmtProto.Checks{
{
Files: []string{},
},
},
expectedBool: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
})
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {

Some files were not shown because too many files have changed in this diff Show More