Compare commits

...

158 Commits

Author SHA1 Message Date
Viktor Liu
ff5eddf70b Merge branch 'main' into add-ns-punnycode-support 2025-06-08 13:14:52 +02:00
Maycon Santos
0f050e5fe1 [client] Optmize process check time (#3938)
This PR optimizes the process check time by updating the implementation of getRunningProcesses and introducing new benchmark tests.

Updated getRunningProcesses to use process.Pids() instead of process.Processes()
Added benchmark tests for both the new and the legacy implementations

Benchmark: https://github.com/netbirdio/netbird/actions/runs/15512741612

todo: evaluate windows optmizations and caching risks
2025-06-08 12:19:54 +02:00
Maycon Santos
0f7c7f1da2 [misc] use generic slack url (#3939) 2025-06-08 10:53:27 +02:00
Maycon Santos
b56f61bf1b [misc] fix relay exposed address test (#3931) 2025-06-05 15:44:44 +02:00
Viktor Liu
64f111923e [client] Increase stun status probe timeout (#3930) 2025-06-05 15:22:59 +02:00
Abdul Latif
122a89c02b [misc] remove error causing dnf config-manager add (#3925) 2025-06-05 14:28:19 +02:00
Robert Neumann
c6cceba381 Update getting-started-with-zitadel.sh - fix zitadel user console (#3446) 2025-06-05 14:16:04 +02:00
Ghazy Abdallah
6c0cdb6ed1 [misc] fix: traefik relay accessibility (#3696) 2025-06-05 14:15:01 +02:00
Viktor Liu
84354951d3 [client] Add systemd netbird logs to debug bundle (#3917) 2025-06-05 13:54:15 +02:00
Viktor Liu
55957a1960 [client] Run registerdns before flushing (#3926)
* Run registerdns before flushing

* Disable WINS, dynamic updates and registration
2025-06-05 12:40:23 +02:00
Viktor Liu
df82a45d99 [client] Improve dns match trace log (#3928) 2025-06-05 12:39:58 +02:00
Zoltan Papp
9424b88db2 [client] Add output similar to wg show to the debug package (#3922) 2025-06-05 11:51:39 +02:00
Viktor Liu
609654eee7 [client] Allow userspace local forwarding to internal interfaces if requested (#3884) 2025-06-04 18:12:48 +02:00
Bethuel Mmbaga
b604c66140 [management] Add postgres support for activity event store (#3890) 2025-06-04 17:38:49 +03:00
Viktor Liu
ea4d13e96d [client] Use platform-native routing APIs for freeBSD, macOS and Windows 2025-06-04 16:28:58 +02:00
Pedro Maia Costa
87148c503f [management] support account retrieval and creation by private domain (#3825)
* [management] sys initiator save user (#3911)

* [management] activity events with multiple external account users (#3914)
2025-06-04 11:21:31 +01:00
Viktor Liu
0cd36baf67 [client] Allow the netbird service to log to console (#3916) 2025-06-03 13:09:39 +02:00
Viktor Liu
06980e7fa0 [client] Apply routes right away instead of on peer connection (#3907) 2025-06-03 10:53:39 +02:00
Viktor Liu
1ce4ee0cef [client] Add block inbound flag to disallow inbound connections of any kind (#3897) 2025-06-03 10:53:27 +02:00
Viktor Liu
f367925496 [client] Log duplicate client ui pid (#3915) 2025-06-03 10:52:10 +02:00
hakansa
616b19c064 [client] Add "Deselect All" Menu Item to Exit Node Menu (#3877)
* [client] Enhance exit node menu functionality with deselect all option

* Hide exit nodes before removal in recreateExitNodeMenu

* recreateExitNodeMenu adding mutex locks

* Refetch exit nodes after deselecting all in exit node menu
2025-06-03 09:49:13 +02:00
Zoltan Papp
af27aaf9af [client] Refactor peer state change subscription mechanism (#3910)
* Refactor peer state change subscription mechanism

Because the code generated new channel for every single event, was easy to miss notification.
Use single channel.

* Fix lint

* Avoid potential deadlock

* Fix test

* Add context

* Fix test
2025-06-03 09:20:33 +02:00
Maycon Santos
35287f8241 [misc] Fail linter workflows on codespell failures (#3913)
* Fail linter workflows on codespell failures

* testing workflow

* remove test
2025-06-03 00:37:51 +02:00
Pedro Maia Costa
07b220d91b [management] REST client impersonation (#3879) 2025-06-02 22:11:28 +02:00
Viktor Liu
41cd4952f1 [client] Apply return traffic rules only if firewall is stateless (#3895) 2025-06-02 12:11:54 +02:00
Zoltan Papp
f16f0c7831 [client] Fix HA router switch (#3889)
* Fix HA router switch.

- Simplify the notification filter logic.
Always send notification if a state has been changed

- Remove IP changes check because we never modify

* Notify only the proper listeners

* Fix test

* Fix TestGetPeerStateChangeNotifierLogic test

* Before lazy connection, when the peer disconnected, the status switched to disconnected.
After implementing lazy connection, the peer state is connecting, so we did not decrease the reference counters on the routes.

* When switch to idle notify the route mgr
2025-06-01 16:08:27 +02:00
Zoltan Papp
aa07b3b87b Fix deadlock (#3904) 2025-05-30 23:38:02 +02:00
Bethuel Mmbaga
2bef214cc0 [management] Fix user groups propagation (#3902) 2025-05-30 18:12:30 +03:00
hakansa
cfb2d82352 [client] Refactor exclude list handling to use a map for permanent connections (#3901)
[client] Refactor exclude list handling to use a map for permanent connections (#3901)
2025-05-30 16:54:49 +03:00
Bethuel Mmbaga
684501fd35 [management] Prevent deletion of peers linked to network routers (#3881)
- Prevent deletion of peers linked to network routers
- Add API endpoint to list all network routers
2025-05-29 18:50:00 +03:00
Zoltan Papp
0492c1724a [client, android] Fix/notifier threading (#3807)
- Fix potential deadlocks
- When adding a listener, immediately notify with the last known IP and fqdn.
2025-05-27 17:12:04 +02:00
Zoltan Papp
6f436e57b5 [server-test] Install libs for i386 tests (#3887)
Install libs for i386 tests
2025-05-27 16:42:06 +02:00
Bethuel Mmbaga
a0d28f9851 [management] Reset test containers after cleanup (#3885) 2025-05-27 14:42:00 +03:00
Zoltan Papp
cdd27a9fe5 [client, android] Fix/android enable server route (#3806)
Enable the server route; otherwise, the manager throws an error and the engine will restart.
2025-05-27 13:32:54 +02:00
Bethuel Mmbaga
5523040acd [management] Add correlated network traffic event schema (#3680) 2025-05-27 13:47:53 +03:00
M. Essam
670446d42e [management/client/rest] Fix panic on unknown errors (#3865) 2025-05-25 16:57:34 +02:00
Viktor Liu
273160c682 [client] Use punycode domains internally consequently (#3867) 2025-05-24 18:25:15 +02:00
Pedro Maia Costa
5bed6777d5 [management] force account id on save groups update (#3850) 2025-05-23 14:42:42 +01:00
Pascal Fischer
a0482ebc7b [client] avoid overwriting state manager on iOS (#3870) 2025-05-23 14:04:12 +02:00
bcmmbaga
1d6c360aec fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-23 13:07:26 +03:00
bcmmbaga
f04e7c3f06 Merge branch 'main' into add-ns-punnycode-support
# Conflicts:
#	management/server/nameserver.go
#	management/server/nameserver_test.go
2025-05-23 13:00:19 +03:00
Bethuel Mmbaga
2a89d6e47a [management] Extend nameserver match domain validation (#3864)
* Enhance match domain validation logic and add test cases

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove the leading dot and root dot support ns regex

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove support for wildcard ns match domain

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 23:16:19 +02:00
bcmmbaga
3d89cd43c2 fix tests
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:30 +03:00
bcmmbaga
0eeda712d0 add support for punycode domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 22:44:12 +03:00
bcmmbaga
3e3268db5f Remove support for wildcard ns match domain
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 19:01:53 +03:00
bcmmbaga
31f0879e71 remove the leading dot and root dot support ns regex
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 18:51:05 +03:00
bcmmbaga
f25b5bb987 Enhance match domain validation logic and add test cases
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-05-22 16:35:45 +03:00
Bethuel Mmbaga
24f932b2ce [management] Update traffic events pagination filters (#3857) 2025-05-22 16:28:14 +03:00
Pedro Maia Costa
c03435061c [management] lazy connection account setting (#3855) 2025-05-22 14:09:00 +01:00
Misha Bragin
8e948739f1 Fix CLA link in the PR template (#3860) 2025-05-22 10:38:58 +02:00
Maycon Santos
9b53cad752 [misc] add CLA note (#3859) 2025-05-21 22:40:36 +02:00
Zoltan Papp
802a18167c [client] Do not reconnect to mgm server in case of handler error (#3856)
* Do not reconnect to mgm server in case of handler error
Set to nil the flow grpc client to nil

* Better error handling
2025-05-21 20:18:21 +02:00
hakansa
e9108ffe6c [client] Add latest gzipped rotated log file to the debug bundle (#3848)
[client] Add latest gzipped rotated log file to the debug bundle
2025-05-21 17:50:54 +03:00
Viktor Liu
e806d9de38 [client] Fix legacy routes when connecting to management servers older than v0.30.0 (#3854) 2025-05-21 13:48:55 +02:00
Zoltan Papp
daa8380df9 [client] Feature/lazy connection (#3379)
With the lazy connection feature, the peer will connect to target peers on-demand. The trigger can be any IP traffic.

This feature can be enabled with the NB_ENABLE_EXPERIMENTAL_LAZY_CONN environment variable.

When the engine receives a network map, it binds a free UDP port for every remote peer, and the system configures WireGuard endpoints for these ports. When traffic appears on a UDP socket, the system removes this listener and starts the peer connection procedure immediately.

Key changes
Fix slow netbird status -d command
Move from engine.go file to conn_mgr.go the peer connection related code
Refactor the iface interface usage and moved interface file next to the engine code
Add new command line flag and UI option to enable feature
The peer.Conn struct is reusable after it has been closed.
Change connection states
Connection states
Idle: The peer is not attempting to establish a connection. This typically means it's in a lazy state or the remote peer is expired.

Connecting: The peer is actively trying to establish a connection. This occurs when the peer has entered an active state and is continuously attempting to reach the remote peer.

Connected: A successful peer-to-peer connection has been established and communication is active.
2025-05-21 11:12:28 +02:00
Bethuel Mmbaga
4785f23fc4 [management] Migrate events sqlite store to gorm (#3837) 2025-05-20 17:00:37 +03:00
Viktor Liu
1d4cfb83e7 [client] Fix UI new version notifier (#3845) 2025-05-20 10:39:17 +02:00
Pascal Fischer
207fa059d2 [management] make locking strength clause optional (#3844) 2025-05-19 16:42:47 +02:00
Viktor Liu
cbcdad7814 [misc] Update issue template (#3842) 2025-05-19 15:41:24 +02:00
Pascal Fischer
701c13807a [management] add flag to disable auto-migration (#3840) 2025-05-19 13:36:24 +02:00
Viktor Liu
99f8dc7748 [client] Offer to remove netbird data in windows uninstall (#3766) 2025-05-16 17:39:30 +02:00
Pascal Fischer
f1de8e6eb0 [management] Make startup period configurable (#3767) 2025-05-16 13:16:51 +02:00
Viktor Liu
b2a10780af [client] Disable dnssec for systemd explicitly (#3831) 2025-05-16 09:43:13 +02:00
Pascal Fischer
43ae79d848 [management] extend rest client lib (#3830) 2025-05-15 18:20:29 +02:00
Pascal Fischer
e520b64c6d [signal] remove stream receive server side (#3820) 2025-05-14 19:28:51 +02:00
hakansa
92c91bbdd8 [client] Add FreeBSD desktop client support to OAuth flow (#3822)
[client] Add FreeBSD desktop client support to OAuth flow
2025-05-14 19:52:02 +03:00
Vlad
adf494e1ac [management] fix a bug with missed extra dns labels for a new peer (#3798) 2025-05-14 17:50:21 +02:00
Vlad
2158461121 [management,client] PKCE add flag parameter prompt=login or max_age (#3824) 2025-05-14 17:48:51 +02:00
Bethuel Mmbaga
0cd4b601c3 [management] Add connection type filter to Network Traffic API (#3815) 2025-05-14 11:15:50 +03:00
Zoltan Papp
ee1cec47b3 [client, android] Do not propagate empty routes (#3805)
If we get domain routes the Network prefix variable in route structure will be invalid (engine.go:1057). When we handower to Android the routes, we must to filter out the domain routes. If we do not do it the Android code will get "invalid prefix" string as a route.
2025-05-13 15:21:06 +02:00
Pascal Fischer
efb0edfc4c [signal] adjust signal log levels 2 (#3817) 2025-05-12 23:52:29 +02:00
Pascal Fischer
20f59ddecb [signal] adjust log levels (#3813) 2025-05-12 19:48:47 +02:00
hakansa
2f34e984b0 [client] Add TCP support to DNS forwarder service listener (#3790)
[client] Add TCP support to DNS forwarder service listener
2025-05-09 15:06:34 +03:00
Viktor Liu
d5b52e86b6 [client] Ignore irrelevant route changes to tracked network monitor routes (#3796) 2025-05-09 14:01:21 +02:00
Zoltan Papp
cad2fe1f39 Return with the correct copied length (#3804) 2025-05-09 13:56:27 +02:00
Pascal Fischer
fcd2c15a37 [management] policy delete cleans policy rules (#3788) 2025-05-07 07:25:25 +02:00
Bethuel Mmbaga
ebda0fc538 [management] Delete service users with account manager (#3793) 2025-05-06 17:31:03 +02:00
M. Essam
ac135ab11d [management/client/rest] fix panic when body is nil (#3714)
Fixes panic occurring when body is nil (this usually happens when connections is refused) due to lack of nil check by centralizing response.Body.Close() behavior.
2025-05-05 18:54:47 +02:00
Pascal Fischer
25faf9283d [management] removal of foreign key constraint enforcement on sqlite (#3786) 2025-05-05 18:21:48 +02:00
hakansa
59faaa99f6 [client] Improve NetBird installation script to handle daemon connection timeout (#3761)
[client] Improve NetBird installation script to handle daemon connection timeout
2025-05-05 17:05:01 +03:00
Viktor Liu
9762b39f29 [client] Fix stale local records (#3776) 2025-05-05 14:29:05 +02:00
Alin Trăistaru
ffdd115ded [client] set TLS ServerName for hostname-based QUIC connections (#3673)
* fix: set TLS ServerName for hostname-based QUIC connections

When connecting to a relay server by hostname, certificates are
validated against the IP address instead of the hostname.
This change sets ServerName in the TLS config when connecting
via hostname, ensuring proper certificate validation.

* use default port if port is missing in URL string
2025-05-05 12:20:54 +02:00
Pascal Fischer
055df9854c [management] add gorm tag for primary key for the networks objects (#3758) 2025-05-04 20:58:04 +02:00
Maycon Santos
12f883badf [management] Optimize load account (#3774) 2025-05-02 00:59:41 +02:00
Maycon Santos
2abb92b0d4 [management] Get account id with order (#3773)
updated log to display account id
2025-05-02 00:25:46 +02:00
Viktor Liu
01c3719c5d [client] Add debug for duration option to netbird ui (#3772) 2025-05-01 23:25:27 +02:00
Pedro Maia Costa
7b64953eed [management] user info with role permissions (#3728) 2025-05-01 11:24:55 +01:00
Viktor Liu
9bc7d788f0 [client] Add debug upload option to netbird ui (#3768) 2025-05-01 00:48:31 +02:00
Pedro Maia Costa
b5419ef11a [management] limit peers based on module read permission (#3757) 2025-04-30 15:53:18 +01:00
Zoltan Papp
d5081cef90 [client] Revert mgm client error handling (#3764) 2025-04-30 13:09:00 +02:00
Bethuel Mmbaga
488e619ec7 [management] Add network traffic events pagination (#3580)
* Add network traffic events pagination schema
2025-04-30 11:51:40 +03:00
hakansa
d2b42c8f68 [client] Add macOS .pkg installer support to installation script (#3755)
[client] Add macOS .pkg installer support to installation script
2025-04-29 13:43:42 +03:00
Maycon Santos
2f44fe2e23 [client] Feature/upload bundle (#3734)
Add an upload bundle option with the flag --upload-bundle; by default, the upload will use a NetBird address, which can be replaced using the flag --upload-bundle-url.

The upload server is available under the /upload-server path. The release change will push a docker image to netbirdio/upload image repository.

The server supports using s3 with pre-signed URL for direct upload and local file for storing bundles.
2025-04-29 00:43:50 +02:00
Bethuel Mmbaga
d8dc107bee [management] Skip IdP cache warm-up on Redis if data exists (#3733)
* Add Redis cache check to skip warm-up on startup if cache is already populated
* Refactor Redis test container setup for reusability
2025-04-28 15:10:40 +03:00
Viktor Liu
3fa915e271 [misc] Exclude client benchmarks from CI (#3752) 2025-04-28 13:40:36 +02:00
Pedro Maia Costa
47c3afe561 [management] add missing network admin mapping (#3751) 2025-04-28 11:05:27 +01:00
hakansa
84bfecdd37 [client] add byte counters & ruleID for routed traffic on userspace (#3653)
* [client] add byte counters for routed traffic on userspace 
* [client] add allowed ruleID for routed traffic on userspace
2025-04-28 10:10:41 +03:00
Viktor Liu
3cf87b6846 [client] Run container tests more generically (#3737) 2025-04-25 18:50:44 +02:00
Maycon Santos
4fe4c2054d [client] Move static check when running on foreground (#3742) 2025-04-25 18:25:48 +02:00
Pascal Fischer
38ada44a0e [management] allow impersonation via pats (#3739) 2025-04-25 16:40:54 +02:00
Pedro Maia Costa
dbf81a145e [management] network admin role (#3720) 2025-04-25 15:14:32 +01:00
Pedro Maia Costa
39483f8ca8 [management] Auditor role (#3721) 2025-04-25 15:04:25 +01:00
Carlos Hernandez
c0eaea938e [client] Fix macos privacy warning when checking static info (#3496)
avoid checking static info with a init call
2025-04-25 14:41:57 +02:00
Viktor Liu
ef8b8a2891 [client] Ensure dst-type local marks can overwrite nat marks (#3738) 2025-04-25 12:43:20 +02:00
Zoltan Papp
2817f62c13 [client] Fix error handling case of flow grpc error (#3727)
When a gRPC error occurs in the Flow package, it will be propagated to the upper layers and handled similarly to a Management gRPC error.

Always report a disconnected state in the event of any error
Hide the underlying gRPC errors
Force close the gRPC connection in the event of any error
2025-04-25 09:26:18 +02:00
Viktor Liu
4a9049566a [client] Set up firewall rules for dns routes dynamically based on dns response (#3702) 2025-04-24 17:37:28 +02:00
Viktor Liu
85f92f8321 [client] Add more userspace filter ACL test cases (#3730) 2025-04-24 12:57:46 +02:00
Viktor Liu
714beb6e3b [client] Fix exit node deselection (#3722) 2025-04-24 12:36:05 +02:00
Viktor Liu
400b9fca32 [management] Add firewall rule route ID and missing route domains (#3700) 2025-04-23 21:29:46 +02:00
hakansa
4013298e22 [client/ui] add connecting state to status handling (#3712) 2025-04-23 21:04:38 +02:00
Pascal Fischer
312bfd9bd7 [management] support custom domains per account (#3726) 2025-04-23 19:36:53 +02:00
Pascal Fischer
8db05838ca [misc] Change github runner for docker test (#3707) 2025-04-23 19:35:26 +02:00
Misha Bragin
c69df13515 [management] Add account meta (#3724) 2025-04-23 18:44:22 +02:00
Pascal Fischer
986eb8c1e0 [management] fix lastLogin on dashboard (#3725) 2025-04-23 15:54:49 +02:00
dependabot[bot]
197761ba4d Bump github.com/redis/go-redis/v9 from 9.7.1 to 9.7.3 (#3553)
Bumps [github.com/redis/go-redis/v9](https://github.com/redis/go-redis) from 9.7.1 to 9.7.3.
- [Release notes](https://github.com/redis/go-redis/releases)
- [Changelog](https://github.com/redis/go-redis/blob/master/CHANGELOG.md)
- [Commits](https://github.com/redis/go-redis/compare/v9.7.1...v9.7.3)

---
updated-dependencies:
- dependency-name: github.com/redis/go-redis/v9
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-23 10:21:36 +02:00
dependabot[bot]
f74ea64c7b Bump golang.org/x/net from 0.36.0 to 0.38.0 (#3695)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.36.0 to 0.38.0.
- [Commits](https://github.com/golang/net/compare/v0.36.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.38.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-23 10:20:51 +02:00
Viktor Liu
3b7b9d25bc [client] Keep new routes selected unless all are deselected (#3692) 2025-04-23 01:07:04 +02:00
Pascal Fischer
1a6d6b3109 [management] fix github run id (#3705) 2025-04-18 11:21:54 +02:00
Pascal Fischer
f686615876 [management] benchmarks use ref_name instead (#3704) 2025-04-17 21:57:54 +02:00
Pascal Fischer
a4311f574d [management] push benchmark results to grafana (#3701) 2025-04-17 21:01:23 +02:00
Pierre Timmermans
0bb8eae903 [docs] fix: broken link in the README file (#3697)
improve README.md, broken link for activity logging
2025-04-17 14:48:10 +02:00
Pascal Fischer
e0b33d325d [management] permissions manager use crud operations (#3690) 2025-04-16 17:25:03 +02:00
Zoltan Papp
c38e07d89a [client] Fix Rosenpass permissive mode handling (#3689)
fixes the Rosenpass preshared key handling to enable successful WireGuard handshakes when one side is in permissive mode. Key changes include:

Updating field accesses from RosenpassPubKey/RosenpassAddr to RosenpassConfig.PubKey/RosenpassConfig.Addr.
Modifying the preshared key computation logic to account for permissive mode.
Revising peer configuration in the Engine to use the new RosenpassConfig struct.
2025-04-16 16:04:43 +02:00
Lamera
a37368fff4 [misc] update gpt file permissions in install.sh (#3663)
* Fix install.sh for some installations

Fix install.sh for some installations by explicitly setting the file permissions

* Add sudo
2025-04-16 14:23:25 +02:00
Viktor Liu
0c93bd3d06 [client] Keep selecting new networks after first deselection (#3671) 2025-04-16 13:55:26 +02:00
Viktor Liu
a675531b5c [client] Set up signal to generate debug bundles (#3683) 2025-04-16 11:06:22 +02:00
hakansa
7cb366bc7d [client] Remove logrus writer assignment in pion logging (#3684) 2025-04-15 18:15:52 +03:00
Viktor Liu
a354004564 [client] Add remaining debug profiles (#3681) 2025-04-15 13:06:28 +02:00
Pedro Maia Costa
75bdd47dfb [management] get current user endpoint (#3666) 2025-04-15 11:06:07 +01:00
Viktor Liu
b165f63327 [client] Add heap profile to debug bundle (#3679) 2025-04-15 11:36:41 +02:00
hakansa
51bb52cdf5 [client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
2025-04-15 10:54:17 +03:00
Pedro Maia Costa
4134b857b4 [management] add permissions manager to geolocation handler (#3665) 2025-04-14 17:57:58 +01:00
Vlad
7839d2c169 [management] Refactor/management/updchannel (#3645)
* refactoring updatechannel - use read mutex for send update
2025-04-11 18:22:59 +03:00
Pascal Fischer
b9f82e2f8a [management] Buffer updateAccountPeers calls (#3644) 2025-04-11 17:21:05 +02:00
Pedro Maia Costa
fd2a21c65d [management] remove unnecessary access control middleware (#3650) 2025-04-11 10:43:59 +01:00
Maycon Santos
82d982b0ab [management,client] Add support to configurable prompt login (#3660) 2025-04-11 11:34:55 +02:00
Maycon Santos
9e24fe7701 [docs] Fix a few typos on table (#3658) 2025-04-10 17:57:39 +02:00
Pedro Maia Costa
e470701b80 [ci] include stash in pr template (#3657) 2025-04-10 16:30:44 +01:00
Viktor Liu
e3ce026355 [client] Fix race dns cleanup race condition (#3652) 2025-04-10 13:21:14 +02:00
Pascal Fischer
5ea2806663 [management] use permission modules (#3622) 2025-04-10 11:06:52 +02:00
Viktor Liu
d6b0673580 [client] Support CNAME in local resolver (#3646) 2025-04-10 10:38:47 +02:00
Pedro Maia Costa
14913cfa7a add git town config (#3555) 2025-04-09 20:18:52 +01:00
Viktor Liu
03f600b576 [client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632) 2025-04-08 13:41:13 +02:00
Viktor Liu
192c97aa63 [client] Support IP fragmentation in userspace (#3639) 2025-04-08 12:49:14 +02:00
Maycon Santos
4db78db49a [misc] Update FreeBSD workflow (#3638)
Update FreeBSD release to 14.2 and download Go package directly since port wasn't finding the package to install
2025-04-08 09:15:09 +02:00
Viktor Liu
87e600a4f3 [client] Automatically register match domains for DNS routes (#3614) 2025-04-07 15:18:45 +02:00
Viktor Liu
6162aeb82d [client] Mark netbird data plane traffic to identify interface traffic correctly (#3623) 2025-04-07 13:14:56 +02:00
hakansa
1ba1e092ce [client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
2025-04-07 15:16:12 +08:00
hakansa
86dbb4ee4f [client] Add no-browser flag to login and up commands for SSO login control (#3610)
* [client] Add no-browser flag to login and up commands for SSO login control (#3610)
2025-04-07 14:39:53 +08:00
hakansa
4af177215f [client] Fix Status Recorder Route Removal Logic to Handle Dynamic Routes Correctly 2025-04-06 09:57:28 +08:00
Viktor Liu
df9c1b9883 [client] Improve TCP conn tracking (#3572) 2025-04-05 11:42:15 +02:00
Viktor Liu
5752bb78f2 [client] Fix missing inbound flows in Linux userspace mode with native router (#3624)
* Fix missing inbound flows in Linux userspace mode with native router

* Fix route enable/disable order for userspace mode
2025-04-05 11:41:31 +02:00
Maycon Santos
fbd783ad58 [client] Use the netbird logger for ice and grpc (#3603)
updates the logging implementation to use the netbird logger for both ICE and gRPC components. The key changes include:

- Introducing a gRPC logger configuration in util/log.go that integrates with the netbird logging setup.
- Updating the log hook in formatter/hook/hook.go to ensure a default caller is used when not set.
- Refactoring ICE agent and UDP multiplexers to use a unified logger via the new getLogger() method.
2025-04-04 18:30:47 +02:00
Viktor Liu
80702b9323 [client] Fix dns forwarder handling of requested record types (#3615) 2025-04-03 13:58:36 +02:00
Viktor Liu
09243a0fe0 [management] Remove remaining backend linux router limitation (#3589) 2025-04-01 21:29:57 +02:00
Maycon Santos
3658215747 [client] Force new user login on PKCE auth in CLI (#3604)
With this change, browser session won't be considered for cli authentication and credentials will be requested
2025-04-01 10:29:29 +02:00
Viktor Liu
48ffec95dd Improve local ip lookup (#3551)
- lower memory footprint in most cases
- increase accuracy
2025-03-31 10:05:57 +02:00
Pedro Maia Costa
cbec7bda80 [management] permission manager validate account access (#3444) 2025-03-30 17:08:22 +02:00
398 changed files with 23043 additions and 9805 deletions

27
.git-branches.toml Normal file
View File

@@ -0,0 +1,27 @@
# More info around this file at https://www.git-town.com/configuration-file
[branches]
main = "main"
perennials = []
perennial-regex = ""
[create]
new-branch-type = "feature"
push-new-branches = false
[hosting]
dev-remote = "origin"
# platform = ""
# origin-hostname = ""
[ship]
delete-tracking-branch = false
strategy = "squash-merge"
[sync]
feature-strategy = "merge"
perennial-strategy = "rebase"
prototype-strategy = "merge"
push-hook = true
tags = true
upstream = false

View File

@@ -37,17 +37,22 @@ If yes, which one?
**Debug output** **Debug output**
To help us resolve the problem, please attach the following debug output To help us resolve the problem, please attach the following anonymized status output
netbird status -dA netbird status -dA
As well as the file created by Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
If applicable, add screenshots to help explain your problem. If applicable, add screenshots to help explain your problem.
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?** **Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions - [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones) - [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client - [ ] Restarted the NetBird client
- [ ] Disabled other VPN software - [ ] Disabled other VPN software
- [ ] Checked firewall settings - [ ] Checked firewall settings

View File

@@ -2,6 +2,10 @@
## Issue ticket number and link ## Issue ticket number and link
## Stack
<!-- branch-stack -->
### Checklist ### Checklist
- [ ] Is it a bug fix - [ ] Is it a bug fix
- [ ] Is a typo/documentation fix - [ ] Is a typo/documentation fix
@@ -9,3 +13,5 @@
- [ ] It is a refactor - [ ] It is a refactor
- [ ] Created tests that fail without the change (if possible) - [ ] Created tests that fail without the change (if possible)
- [ ] Extended the README / documentation, if necessary - [ ] Extended the README / documentation, if necessary
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).

21
.github/workflows/git-town.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: Git Town
on:
pull_request:
branches:
- '**'
jobs:
git-town:
name: Display the branch stack
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
with:
skip-single-stacks: true

View File

@@ -22,14 +22,20 @@ jobs:
with: with:
usesh: true usesh: true
copyback: false copyback: false
release: "14.1" release: "14.2"
prepare: | prepare: |
pkg install -y go pkgconf xorg pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands # -x - to print all executed commands
# -e - to faile on first error # -e - to faile on first error
run: | run: |
set -e -x set -e -x
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd # check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/... time go test -timeout 1m -failfast ./base62/...

View File

@@ -146,6 +146,65 @@ jobs:
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
${{ steps.go-env.outputs.cache_dir }}
${{ steps.go-env.outputs.modcache_dir }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Run tests in container
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
docker run --rm \
--cap-add=NET_ADMIN \
--privileged \
-v $PWD:/app \
-w /app \
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
-e CGO_ENABLED=1 \
-e CI=true \
-e DOCKER_CI=true \
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay: test_relay:
name: "Relay / Unit" name: "Relay / Unit"
needs: [build-cache] needs: [build-cache]
@@ -164,6 +223,10 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -179,13 +242,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -217,6 +273,10 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
- name: Get Go environment - name: Get Go environment
run: | run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
@@ -232,13 +292,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -286,13 +339,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -314,6 +360,7 @@ jobs:
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \ go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... -timeout 20m ./management/...
@@ -353,13 +400,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -380,10 +420,11 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./... -timeout 20m ./management/...
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -396,6 +437,33 @@ jobs:
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -420,13 +488,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -447,11 +508,13 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \ go test -tags=benchmark \
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... -timeout 20m ./management/...
api_integration_test: api_integration_test:
@@ -489,13 +552,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -505,89 +561,8 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... -timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin
- name: Run Shared Sock tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -21,7 +21,6 @@ jobs:
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1
golangci: golangci:
strategy: strategy:
fail-fast: false fail-fast: false

View File

@@ -172,12 +172,14 @@ jobs:
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
# check relay values # check relay values
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
grep '33445:33445' docker-compose.yml grep '33445:33445' docker-compose.yml
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$' grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445" grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"' grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy

View File

@@ -96,6 +96,20 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
binary: netbird-upload
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries: universal_binaries:
- id: netbird - id: netbird
@@ -409,6 +423,52 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests: docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }} - name_template: netbirdio/netbird:{{ .Version }}
image_templates: image_templates:
@@ -475,7 +535,17 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-arm64v8 - netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm - netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
brews: brews:
- ids: - ids:
- default - default

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g"> <a href="https://docs.netbird.io/slack-url">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,7 +29,7 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a> Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
<br/> <br/>
</strong> </strong>
@@ -58,15 +58,15 @@
### Key features ### Key features
| Connectivity | Management | Security | Automation| Platforms | | Connectivity | Management | Security | Automation| Platforms |
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------| |----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> | | <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> | | <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> | | <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> | | <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> | | <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | ||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | ||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | ||||| <ul><li>- \[x] Docker</ui></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud

View File

@@ -1,5 +1,6 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
RUN apk add --no-cache ca-certificates iptables ip6tables # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird COPY netbird /usr/local/bin/netbird

View File

@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
return a.ipAnonymizer[ip] return a.ipAnonymizer[ip]
} }
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
// Convert IP to netip.Addr
ip, ok := netip.AddrFromSlice(addr.IP)
if !ok {
return addr
}
anonIP := a.AnonymizeIP(ip)
return net.UDPAddr{
IP: anonIP.AsSlice(),
Port: addr.Port,
Zone: addr.Zone,
}
}
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs // isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool { func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 { if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {

View File

@@ -11,9 +11,12 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
const errCloseConnection = "Failed to close connection: %v" const errCloseConnection = "Failed to close connection: %v"
@@ -84,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}() }()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag), Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
}) }
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
cmd.Printf("Local file:\n%s\n", resp.GetPath())
cmd.Println(resp.GetPath()) if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil return nil
} }
@@ -208,23 +222,19 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput, Status: statusOutput,
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
}) }
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -239,7 +249,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel()) cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Println(resp.GetPath()) cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil return nil
} }
@@ -326,3 +344,34 @@ func formatDuration(d time.Duration) string {
s := d / time.Second s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s) return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
} }
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
if connectClient != nil {
networkMap, err = connectClient.GetLatestNetworkMap()
if err != nil {
log.Warnf("Failed to get latest network map: %v", err)
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: config,
StatusRecorder: recorder,
NetworkMap: networkMap,
LogFile: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
log.Errorf("Failed to generate debug bundle: %v", err)
return
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}

39
client/cmd/debug_unix.go Normal file
View File

@@ -0,0 +1,39 @@
//go:build unix
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
usr1Ch := make(chan os.Signal, 1)
signal.Notify(usr1Ch, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-usr1Ch:
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
go generateDebugBundle(config, recorder, connectClient, logFilePath)
}
}
}()
}

126
client/cmd/debug_windows.go Normal file
View File

@@ -0,0 +1,126 @@
package cmd
import (
"context"
"errors"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
waitTimeout = 5 * time.Second
)
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
// Example usage with PowerShell:
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
// $evt.Set()
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
env := os.Getenv(envListenEvent)
if env == "" {
return
}
listenEvent, err := strconv.ParseBool(env)
if err != nil {
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
return
}
if !listenEvent {
return
}
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
return
}
// TODO: restrict access by ACL
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
} else {
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
return
}
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
}
func waitForEvent(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
eventHandle windows.Handle,
) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on debug event. Triggering debug bundle generation.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
}
go generateDebugBundle(config, recorder, connectClient, logFilePath)
case uint32(windows.WAIT_TIMEOUT):
default:
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
case <-ctx.Done():
return
}
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@@ -19,6 +20,10 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the Netbird Management Service (first run)", Short: "login to the Netbird Management Service (first run)",
@@ -51,6 +56,9 @@ var loginCmd = &cobra.Command{
return err return err
} }
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
@@ -93,7 +101,7 @@ var loginCmd = &cobra.Command{
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
} }
@@ -127,7 +135,7 @@ var loginCmd = &cobra.Command{
} }
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {
@@ -188,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop()) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -198,7 +206,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout) waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
@@ -212,23 +220,34 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
} }
if noBrowser {
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
} else {
cmd.Println("Please do the SSO login in your browser. \n" + cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" + "If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg) verificationURIComplete + " " + codeMsg)
}
cmd.Println("") cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil { if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
} }

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
) )
const ( const (
@@ -38,7 +39,9 @@ const (
extraIFaceBlackListFlag = "extra-iface-blacklist" extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access" enableLazyConnectionFlag = "enable-lazy-connection"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -74,7 +77,9 @@ var (
anonymizeFlag bool anonymizeFlag bool
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool debugUploadBundle bool
debugUploadBundleURL string
lazyConnEnabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -179,8 +184,11 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.") upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"runtime"
"sync" "sync"
"github.com/kardianos/service" "github.com/kardianos/service"
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
} }
func newSVCConfig() *service.Config { func newSVCConfig() *service.Config {
return &service.Config{ config := &service.Config{
Name: serviceName, Name: serviceName,
DisplayName: "Netbird", DisplayName: "Netbird",
Description: "A WireGuard-based mesh network that connects your devices into a single private network.", Description: "Netbird mesh network client",
Option: make(service.KeyValue), Option: make(service.KeyValue),
EnvVars: make(map[string]string),
} }
if runtime.GOOS == "linux" {
config.EnvVars["SYSTEMD_UNIT"] = serviceName
}
return config
} }
func newSVC(prg *program, conf *service.Config) (service.Service, error) { func newSVC(prg *program, conf *service.Config) (service.Service, error) {

View File

@@ -16,12 +16,17 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint log.Info("starting Netbird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {

View File

@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL) svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
} }
if logFile != "console" { if logFile != "" {
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile) svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
} }

View File

@@ -44,7 +44,7 @@ func init() {
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4") statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200") statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud") statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected") statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
} }
func statusFunc(cmd *cobra.Command, args []string) error { func statusFunc(cmd *cobra.Command, args []string) error {
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "disconnected", "connected": case "", "idle", "connecting", "connected":
if strings.ToLower(statusFilter) != "" { if strings.ToLower(statusFilter) != "" {
enableDetailFlagWhenFilterFlag() enableDetailFlagWhenFilterFlag()
} }
default: default:
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter) return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
} }
if len(ipsFilter) > 0 { if len(ipsFilter) > 0 {

View File

@@ -6,6 +6,8 @@ const (
disableServerRoutesFlag = "disable-server-routes" disableServerRoutesFlag = "disable-server-routes"
disableDNSFlag = "disable-dns" disableDNSFlag = "disable-dns"
disableFirewallFlag = "disable-firewall" disableFirewallFlag = "disable-firewall"
blockLANAccessFlag = "block-lan-access"
blockInboundFlag = "block-inbound"
) )
var ( var (
@@ -13,6 +15,8 @@ var (
disableServerRoutes bool disableServerRoutes bool
disableDNS bool disableDNS bool
disableFirewall bool disableFirewall bool
blockLANAccess bool
blockInbound bool
) )
func init() { func init() {
@@ -28,4 +32,11 @@ func init() {
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false, upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
"Disable firewall configuration. If enabled, the client won't modify firewall rules.") "Disable firewall configuration. If enabled, the client won't modify firewall rules.")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
"Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.")
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -91,13 +92,18 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish) t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager) settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
Example: ` Example: `
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
Args: cobra.ExactArgs(3), Args: cobra.ExactArgs(3),
RunE: tracePacket, RunE: tracePacket,

View File

@@ -32,12 +32,16 @@ const (
const ( const (
dnsLabelsFlag = "extra-dns-labels" dnsLabelsFlag = "extra-dns-labels"
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
) )
var ( var (
foregroundMode bool foregroundMode bool
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
@@ -51,12 +55,11 @@ func init() {
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name") upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port") upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor, upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+ `Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`, `E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
) )
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening") upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval") upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil, upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+ `Sets DNS labels`+
@@ -65,6 +68,9 @@ func init() {
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+ `E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`, `or --extra-dns-labels ""`,
) )
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -112,6 +118,124 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
return err return err
} }
ic, err := setupConfig(customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup config: %v", err)
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
}
config, err := internal.UpdateOrCreateConfig(*ic)
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
cmd.Println("Already connected")
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return fmt.Errorf("get setup key: %v", err)
}
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
if err != nil {
return fmt.Errorf("setup login request: %v", err)
}
var loginErr error
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
}
if loginErr != nil {
return fmt.Errorf("login failed: %v", loginErr)
}
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
}
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
ic := internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
@@ -136,7 +260,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return err return nil, err
} }
ic.InterfaceName = &interfaceName ic.InterfaceName = &interfaceName
} }
@@ -187,71 +311,17 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
ic.BlockLANAccess = &blockLANAccess ic.BlockLANAccess = &blockLANAccess
} }
providedSetupKey, err := getSetupKey() if cmd.Flag(blockInboundFlag).Changed {
if err != nil { ic.BlockInbound = &blockInbound
return err
} }
config, err := internal.UpdateOrCreateConfig(ic) if cmd.Flag(enableLazyConnectionFlag).Changed {
if err != nil { ic.LazyConnectionEnabled = &lazyConnEnabled
return fmt.Errorf("get config file: %v", err) }
} return &ic, nil
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
SetupCloseHandler(ctx, cancel)
r := peer.NewRecorder(config.ManagementURL.String())
r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r)
return connectClient.Run(nil)
}
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
if err != nil {
return err
}
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing daemon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status == string(internal.StatusConnected) {
cmd.Println("Already connected")
return nil
}
providedSetupKey, err := getSetupKey()
if err != nil {
return err
} }
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
@@ -259,7 +329,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels, DnsLabels: dnsLabels,
@@ -288,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return err return nil, err
} }
loginRequest.InterfaceName = &interfaceName loginRequest.InterfaceName = &interfaceName
} }
@@ -323,45 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
loginRequest.BlockLanAccess = &blockLANAccess loginRequest.BlockLanAccess = &blockLANAccess
} }
var loginErr error if cmd.Flag(blockInboundFlag).Changed {
loginRequest.BlockInbound = &blockInbound
var loginResp *proto.LoginResponse
err = WithBackOff(func() error {
var backOffErr error
loginResp, backOffErr = client.Login(ctx, &loginRequest)
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
s.Code() == codes.PermissionDenied ||
s.Code() == codes.NotFound ||
s.Code() == codes.Unimplemented) {
loginErr = backOffErr
return nil
}
return backOffErr
})
if err != nil {
return fmt.Errorf("login backoff cycle failed: %v", err)
} }
if loginErr != nil { if cmd.Flag(enableLazyConnectionFlag).Changed {
return fmt.Errorf("login failed: %v", loginErr) loginRequest.LazyConnectionEnabled = &lazyConnEnabled
} }
return &loginRequest, nil
if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
return fmt.Errorf("waiting sso login failed with: %v", err)
}
}
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
return fmt.Errorf("call service up method: %v", err)
}
cmd.Println("Connected")
return nil
} }
func validateNATExternalIPs(list []string) error { func validateNATExternalIPs(list []string) error {

View File

@@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -148,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -199,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil, nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
@@ -220,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }
@@ -243,6 +252,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -2,7 +2,7 @@ package iptables
import ( import (
"fmt" "fmt"
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
var rule2 []fw.Rule var rule2 []fw.Rule
t.Run("add second rule", func(t *testing.T) { t.Run("add second rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.3") ip := netip.MustParseAddr("10.20.0.3")
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default") rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: netip.MustParseAddr("10.20.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("10.20.0.0/24"),
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -41,6 +41,8 @@ const (
jumpManglePre = "jump-mangle-pre" jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre" jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post" jumpNatPost = "jump-nat-post"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set" matchSet = "--match-set"
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
@@ -55,18 +57,18 @@ type ruleInfo struct {
} }
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Sources []netip.Prefix Source firewall.Network
Destination netip.Prefix Destination firewall.Network
Proto firewall.Protocol Proto firewall.Protocol
SPort *firewall.Port SPort *firewall.Port
DPort *firewall.Port DPort *firewall.Port
Direction firewall.RuleDirection Direction firewall.RuleDirection
Action firewall.Action Action firewall.Action
SetName string
} }
type routeRules map[string][]string type routeRules map[string][]string
// the ipset library currently does not support comments, so we use the name only (string)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct { type router struct {
@@ -115,6 +117,10 @@ func (r *router) init(stateManager *statemanager.Manager) error {
return fmt.Errorf("create containers: %w", err) return fmt.Errorf("create containers: %w", err)
} }
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.updateState() r.updateState()
return nil return nil
@@ -123,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -134,27 +140,28 @@ func (r *router) AddRouteFiltering(
return ruleKey, nil return ruleKey, nil
} }
var setName string var source firewall.Network
if len(sources) > 1 { if len(sources) > 1 {
setName = firewall.GenerateSetName(sources) source.Set = firewall.NewPrefixSet(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { } else if len(sources) > 0 {
return nil, fmt.Errorf("create or get ipset: %w", err) source.Prefix = sources[0]
}
} }
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Sources: sources, Source: source,
Destination: destination, Destination: destination,
Proto: proto, Proto: proto,
SPort: sPort, SPort: sPort,
DPort: dPort, DPort: dPort,
Action: action, Action: action,
SetName: setName,
} }
rule := genRouteFilteringRuleSpec(params) rule, err := r.genRouteRuleSpec(params, sources)
if err != nil {
return nil, fmt.Errorf("generate route rule spec: %w", err)
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end // Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
// after the established rule // after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
@@ -177,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if setName != "" { if err := r.decrementSetCounter(rule); err != nil {
if _, err := r.ipsetCounter.Decrement(setName); err != nil { return fmt.Errorf("decrement ipset counter: %w", err)
return fmt.Errorf("failed to remove ipset: %w", err)
}
} }
} else { } else {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@@ -198,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil return nil
} }
func (r *router) findSetNameInRule(rule []string) string { func (r *router) decrementSetCounter(rule []string) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule []string) []string {
var sets []string
for i, arg := range rule { for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3] sets = append(sets, rule[i+3])
} }
} }
return "" return sets
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
@@ -225,15 +241,13 @@ func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil { if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err) return fmt.Errorf("destroy set %s: %w", setName, err)
} }
log.Debugf("Deleted unused ipset %s", setName)
return nil return nil
} }
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -260,10 +274,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil { if pair.Masquerade {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
@@ -271,6 +282,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove inverse nat rule: %w", err)
} }
}
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err) return fmt.Errorf("remove legacy routing rule: %w", err)
@@ -307,8 +319,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey) if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} }
return nil return nil
@@ -348,12 +362,16 @@ func (r *router) Reset() error {
if err := r.cleanUpDefaultForwardRules(); err != nil { if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil { if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
r.updateState() r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
@@ -423,6 +441,57 @@ func (r *router) createContainers() error {
return nil return nil
} }
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) addPostroutingRules() error { func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade // First rule for outbound masquerade
rule1 := []string{ rule1 := []string{
@@ -464,7 +533,7 @@ func (r *router) insertEstablishedRule(chain string) error {
} }
func (r *router) addJumpRules() error { func (r *router) addJumpRules() error {
// Jump to NAT chain // Jump to nat chain
natRule := []string{"-j", chainRTNAT} natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err) return fmt.Errorf("add nat postrouting jump rule: %v", err)
@@ -538,12 +607,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
rule = append(rule, rule = append(rule,
"-m", "conntrack", "-m", "conntrack",
"--ctstate", "NEW", "--ctstate", "NEW",
"-s", pair.Source.String(), )
"-d", pair.Destination.String(), sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
) )
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { // Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
@@ -561,6 +644,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else { } else {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
@@ -726,17 +813,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
var rule []string var rule []string
if params.SetName != "" { sourceExp, err := r.applyNetwork("-s", params.Source, sources)
rule = append(rule, "-m", "set", matchSet, params.SetName, "src") if err != nil {
} else if len(params.Sources) > 0 { return nil, fmt.Errorf("apply network -s: %w", err)
source := params.Sources[0]
rule = append(rule, "-s", source.String()) }
destExp, err := r.applyNetwork("-d", params.Destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
} }
rule = append(rule, "-d", params.Destination.String()) rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
@@ -746,7 +837,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
rule = append(rule, "-j", actionToStr(params.Action)) rule = append(rule, "-j", actionToStr(params.Action))
return rule return rule, nil
}
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
} }
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {

View File

@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 5. jump rule to PRE nat chain // 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule // 6. static outbound masquerade rule
// 7. static return masquerade rule // 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map") // 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -58,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
Masquerade: true, Masquerade: true,
} }
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
@@ -345,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content // Verify rule content
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Sources: tt.sources, Source: source,
Destination: tt.destination, Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto, Proto: tt.proto,
SPort: tt.sPort, SPort: tt.sPort,
DPort: tt.dPort, DPort: tt.dPort,
Action: tt.action, Action: tt.action,
SetName: "",
} }
expectedRule := genRouteFilteringRuleSpec(params) expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet { if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources) setName := firewall.NewPrefixSet(tt.sources).HashedName()
params.SetName = setName expectedRule, err = r.genRouteRuleSpec(params, nil)
expectedRule = genRouteFilteringRuleSpec(params) require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created // Check if the set was created
_, exists := r.ipsetCounter.Get(setName) _, exists := r.ipsetCounter.Get(setName)
@@ -376,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
} }
} }
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
expected []string
}{
{
name: "Basic rule with two sets",
rule: []string{
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
},
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
},
{
name: "No sets",
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
expected: []string{},
},
{
name: "Multiple sets with different positions",
rule: []string{
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
},
expected: []string{"set1", "set-abc123"},
},
{
name: "Boundary case - sequence appears at end",
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
expected: []string{"final-set"},
},
{
name: "Incomplete pattern - missing set name",
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
expected: []string{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
return
}
for i, set := range result {
if set != tc.expected[i] {
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
}
}
})
}
}

View File

@@ -1,13 +1,10 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sort" "sort"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -43,6 +40,18 @@ const (
// Action is the action to be taken on a rule // Action is the action to be taken on a rule
type Action int type Action int
// String returns the string representation of the action
func (a Action) String() string {
switch a {
case ActionAccept:
return "accept"
case ActionDrop:
return "drop"
default:
return "unknown"
}
}
const ( const (
// ActionAccept is the action to accept a packet // ActionAccept is the action to accept a packet
ActionAccept Action = iota ActionAccept Action = iota
@@ -50,6 +59,33 @@ const (
ActionDrop ActionDrop
) )
// Network is a rule destination, either a set or a prefix
type Network struct {
Set Set
Prefix netip.Prefix
}
// String returns the string representation of the destination
func (d Network) String() string {
if d.Prefix.IsValid() {
return d.Prefix.String()
}
if d.IsSet() {
return d.Set.HashedName()
}
return "<invalid network>"
}
// IsSet returns true if the destination is a set
func (d Network) IsSet() bool {
return d.Set != Set{}
}
// IsPrefix returns true if the destination is a valid prefix
func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// Manager is the high level abstraction of a firewall manager // Manager is the high level abstraction of a firewall manager
// //
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
@@ -80,13 +116,14 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
IsStateful() bool
AddRouteFiltering( AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination Network,
proto Protocol, proto Protocol,
sPort *Port, sPort, dPort *Port,
dPort *Port,
action Action, action Action,
) (Rule, error) ) (Rule, error)
@@ -119,6 +156,9 @@ type Manager interface {
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {
@@ -153,22 +193,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
return nil return nil
} }
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 { if len(prefixes) == 0 {

View File

@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.GenerateSetName(prefixes1) result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.GenerateSetName(prefixes2) result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("10.0.0.0/8"),
} }
result := manager.GenerateSetName(prefixes) result := manager.NewPrefixSet(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
if err != nil { if err != nil {
t.Fatalf("Error matching regex: %v", err) t.Fatalf("Error matching regex: %v", err)
} }
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
}) })
t.Run("Empty input produces consistent result", func(t *testing.T) { t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{}) result1 := manager.NewPrefixSet([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{}) result2 := manager.NewPrefixSet([]netip.Prefix{})
if result1 != result2 { if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.GenerateSetName(prefixes1) result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.GenerateSetName(prefixes2) result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)

View File

@@ -1,15 +1,13 @@
package manager package manager
import ( import (
"net/netip"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type RouterPair struct { type RouterPair struct {
ID route.ID ID route.ID
Source netip.Prefix Source Network
Destination netip.Prefix Destination Network
Masquerade bool Masquerade bool
Inverse bool Inverse bool
} }

View File

@@ -0,0 +1,74 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
type Set struct {
hash [4]byte
comment string
}
// String returns the string representation of the set: hashed name and comment
func (h Set) String() string {
if h.comment == "" {
return h.HashedName()
}
return h.HashedName() + ": " + h.comment
}
// HashedName returns the string representation of the hash
func (h Set) HashedName() string {
return fmt.Sprintf(
"nb-%s",
hex.EncodeToString(h.hash[:]),
)
}
// Comment returns the comment of the set
func (h Set) Comment() string {
return h.comment
}
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()
for _, src := range prefixes {
bytes, err := src.MarshalBinary()
if err != nil {
log.Warnf("failed to marshal prefix %s: %v", src, err)
}
hash.Write(bytes)
}
var set Set
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}
// NewDomainSet generates a unique name for an ipset based on the given domains.
func NewDomainSet(domains domain.List) Set {
slices.Sort(domains)
hash := sha256.New()
for _, d := range domains {
hash.Write([]byte(d.PunycodeString()))
}
set := Set{
comment: domains.SafeString(),
}
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}

View File

@@ -27,7 +27,8 @@ const (
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting" chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
@@ -462,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP. // netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{ // Chain is created by route manager
Name: chainNamePrerouting, // TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable, Table: m.workTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
}) }
m.addFwmarkToForward(chainFwFilter) m.addFwmarkToForward(chainFwFilter)

View File

@@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
} }
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
@@ -171,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return true
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -242,7 +245,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -325,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
} }
func (m *Manager) EnableRouting() error { func (m *Manager) EnableRouting() error {
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
return fmt.Errorf("enable IP forwarding: %w", err)
}
return nil return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
return fmt.Errorf("disable IP forwarding: %w", err)
}
return nil return nil
} }
@@ -359,6 +368,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -3,7 +3,6 @@ package nftables
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"os/exec" "os/exec"
"testing" "testing"
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1").Unmap()
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "") rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
} }
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
expectedExprs2 := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
&expr.Cmp{ &expr.Cmp{
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: add.AsSlice(), Data: ip.AsSlice(),
}, },
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
}, },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: netip.MustParseAddr("100.96.0.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.96.0.0/16"),
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
} }
}, },
} }
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
ip := net.ParseIP("10.20.0.100") ip := netip.MustParseAddr("10.20.0.100")
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -282,14 +273,14 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
}) })
ip := net.ParseIP("100.96.0.1") ip := netip.MustParseAddr("100.96.0.1")
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil, nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"), fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
&fw.Port{Values: []uint16{443}}, &fw.Port{Values: []uint16{443}},
@@ -298,8 +289,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, err, "failed to add route filtering rule") require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{ pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"), Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
Destination: netip.MustParsePrefix("10.0.0.0/24"), Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
Masquerade: true, Masquerade: true,
} }
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)

View File

@@ -10,7 +10,6 @@ import (
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
@@ -44,9 +43,14 @@ const (
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
var ( var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") errFilterTableNotFound = fmt.Errorf("'filter' table not found")
) )
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
type router struct { type router struct {
conn *nftables.Conn conn *nftables.Conn
workTable *nftables.Table workTable *nftables.Table
@@ -54,7 +58,7 @@ type router struct {
chains map[string]*nftables.Chain chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@@ -100,6 +104,10 @@ func (r *router) init(workTable *nftables.Table) error {
return fmt.Errorf("create containers: %w", err) return fmt.Errorf("create containers: %w", err)
} }
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
return nil return nil
} }
@@ -159,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err) return nil, fmt.Errorf("unable to list tables: %v", err)
} }
for _, table := range tables { for _, table := range tables {
@@ -196,15 +204,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Chain is created by acl manager r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
// TODO: move creation to a common place Name: chainNameManglePostrouting,
r.chains[chainNamePrerouting] = &nftables.Chain{ Table: r.workTable,
Name: chainNamePrerouting, Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable, Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
} })
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil { if err := r.addPostroutingRules(); err != nil {
@@ -220,7 +234,83 @@ func (r *router) createContainers() error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err) return fmt.Errorf("initialize tables: %v", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
} }
return nil return nil
@@ -230,7 +320,7 @@ func (r *router) createContainers() error {
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -245,23 +335,29 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw] chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any var exprs []expr.Any
var source firewall.Network
switch { switch {
case len(sources) == 1 && sources[0].Bits() == 0: case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching // If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1: case len(sources) == 1:
// If there's only one source, we can use it directly // If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) source.Prefix = sources[0]
default: default:
// If there are multiple sources, create or get an ipset // If there are multiple sources, use a set
var err error source.Set = firewall.NewPrefixSet(sources)
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
} }
// Handle destination sourceExp, err := r.applyNetwork(source, sources, true)
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle protocol // Handle protocol
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
@@ -305,39 +401,27 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule) rule = r.conn.AddRule(rule)
} }
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf(flushError, err)
} }
r.rules[string(ruleKey)] = rule r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil return ruleKey, nil
} }
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources) ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
ref, err := r.ipsetCounter.Increment(setName, sources) set: set,
prefixes: prefixes,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
exprs = append(exprs, return getIpSetExprs(ref, isSource)
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -356,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey) return fmt.Errorf("route rule %s has no handle", ruleKey)
} }
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil { if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err) return fmt.Errorf("delete: %w", err)
} }
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil return nil
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them // overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources) prefixes := firewall.MergeIPRanges(input.prefixes)
set := &nftables.Set{ nfset := &nftables.Set{
Name: setName, Name: setName,
Comment: input.set.Comment(),
Table: r.workTable, Table: r.workTable,
// required for prefixes // required for prefixes
Interval: true, Interval: true,
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
} }
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement var elements []nftables.SetElement
for _, prefix := range sources { for _, prefix := range prefixes {
// TODO: Implement IPv6 support // TODO: Implement IPv6 support
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue continue
} }
@@ -407,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
) )
} }
return elements
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
} }
// calculateLastIP determines the last IP in a given prefix. // calculateLastIP determines the last IP in a given prefix.
@@ -442,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b return b
} }
func (r *router) deleteIpSet(setName string, set *nftables.Set) error { func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(set) r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@@ -452,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
return nil return nil
} }
func (r *router) findSetNameInRule(rule *nftables.Rule) string { func (r *router) decrementSetCounter(rule *nftables.Rule) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs { for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok { if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName sets = append(sets, lookup.SetName)
} }
} }
return "" return sets
} }
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@@ -474,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -500,7 +595,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) // TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
} }
return nil return nil
@@ -508,8 +604,15 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue // addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error { func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq op := expr.CmpOpEq
if pair.Inverse { if pair.Inverse {
@@ -517,26 +620,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
exprs := []expr.Any{ exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyIIFNAME, Key: expr.MetaKeyIIFNAME,
Register: 1, Register: 1,
@@ -547,6 +630,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Data: ifname(r.wgIface.Name()), Data: ifname(r.wgIface.Name()),
}, },
} }
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
@@ -576,9 +662,11 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ // Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNamePrerouting], Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
@@ -659,8 +747,15 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Counter{}, &expr.Counter{},
@@ -669,7 +764,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}, },
} }
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@@ -682,7 +778,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingFw], Chain: r.chains[chainNameRoutingFw],
Exprs: expression, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil return nil
@@ -697,11 +793,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} }
return nil return nil
@@ -904,14 +1002,11 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
@@ -919,16 +1014,17 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
} }
}
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err) return fmt.Errorf("remove legacy routing rule: %w", err)
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) // TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
} }
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil return nil
} }
@@ -936,16 +1032,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule) if err := r.conn.DelRule(rule); err != nil {
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else { } else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey) log.Debugf("prerouting rule %s not found", ruleKey)
} }
return nil return nil
@@ -957,7 +1056,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains { for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain) rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err) return fmt.Errorf(" unable to list rules: %v", err)
} }
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
@@ -1231,13 +1330,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
var offset uint32 if err != nil {
if source { return fmt.Errorf("get set %s: %w", set.HashedName(), err)
offset = 12 // src offset }
} else {
offset = 16 // dst offset elements := convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
} }
ones := prefix.Bits() ones := prefix.Bits()
@@ -1324,3 +1464,48 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs return exprs
} }
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range rtr.chains { for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting { if chain.Name == chainNameManglePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was added // Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules") require.NoError(t, err, "should list rules")
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was removed // Verify the rule was removed
found = false found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal") require.NoError(t, err, "should list rules after removal")
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources) setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, tt.sources) set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
if err != nil { if err != nil {
t.Logf("Failed to create IP set: %v", err) t.Logf("Failed to create IP set: %v", err)
printNftSets() printNftSets()

View File

@@ -15,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -24,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -40,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -12,7 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -22,7 +21,7 @@ const (
firewallRuleName = "Netbird" firewallRuleName = "Netbird"
) )
// Reset firewall to the default state // Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if fwder := m.forwarder.Load(); fwder != nil { if fwder := m.forwarder.Load(); fwder != nil {

View File

@@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }

View File

@@ -23,11 +23,11 @@ const (
) )
const ( const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01 TCPFin uint8 = 0x01
TCPSyn uint8 = 0x02
TCPRst uint8 = 0x04 TCPRst uint8 = 0x04
TCPPush uint8 = 0x08 TCPPush uint8 = 0x08
TCPAck uint8 = 0x10
TCPUrg uint8 = 0x20 TCPUrg uint8 = 0x20
) )
@@ -41,7 +41,7 @@ const (
) )
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
type TCPState int type TCPState int32
func (s TCPState) String() string { func (s TCPState) String() string {
switch s { switch s {
@@ -91,20 +91,23 @@ type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16 SourcePort uint16
DestPort uint16 DestPort uint16
State TCPState state atomic.Int32
established atomic.Bool
tombstone atomic.Bool tombstone atomic.Bool
sync.RWMutex
} }
// IsEstablished safely checks if connection is established // GetState safely retrieves the current state
func (t *TCPConnTrack) IsEstablished() bool { func (t *TCPConnTrack) GetState() TCPState {
return t.established.Load() return TCPState(t.state.Load())
} }
// SetEstablished safely sets the established state // SetState safely updates the current state
func (t *TCPConnTrack) SetEstablished(state bool) { func (t *TCPConnTrack) SetState(state TCPState) {
t.established.Store(state) t.state.Store(int32(state))
}
// CompareAndSwapState atomically changes the state from old to new if current == old
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
return t.state.CompareAndSwap(int32(old), int32(newState))
} }
// IsTombstone safely checks if the connection is marked for deletion // IsTombstone safely checks if the connection is marked for deletion
@@ -125,13 +128,17 @@ type TCPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
waitTimeout time.Duration
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
waitTimeout := TimeWaitTimeout
if timeout == 0 { if timeout == 0 {
timeout = DefaultTCPTimeout timeout = DefaultTCPTimeout
} else {
waitTimeout = timeout / 45
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
timeout: timeout, timeout: timeout,
waitTimeout: waitTimeout,
flowLogger: flowLogger, flowLogger: flowLogger,
} }
@@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{ key := ConnKey{
SrcIP: srcIP, SrcIP: srcIP,
DstIP: dstIP, DstIP: dstIP,
@@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
t.mutex.RUnlock() t.mutex.RUnlock()
if exists { if exists {
conn.Lock() t.updateState(key, conn, flags, direction, size)
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
@@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) { func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists { if exists || flags&TCPSyn == 0 {
return return
} }
@@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
DestPort: dstPort, DestPort: dstPort,
} }
conn.established.Store(false)
conn.tombstone.Store(false) conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key) t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress) t.updateState(key, conn, flags, direction, size)
conn.UpdateCounters(direction, size)
t.mutex.Lock() t.mutex.Lock()
t.connections[key] = conn t.connections[key] = conn
@@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool { func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{ key := ConnKey{
SrcIP: dstIP, SrcIP: dstIP,
DstIP: srcIP, DstIP: srcIP,
@@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if !exists || conn.IsTombstone() {
return false return false
} }
// Handle RST flag specially - it always causes transition to closed currentState := conn.GetState()
if flags&TCPRst != 0 { if !t.isValidStateForFlags(currentState, flags) {
return t.handleRst(key, conn, size) t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
} // allow all flags for established for now
if currentState == TCPStateEstablished {
conn.Lock()
t.updateState(key, conn, flags, false)
isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
return isEstablished || isValidState
}
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool {
if conn.IsTombstone() {
return true return true
} }
return false
}
conn.Lock() t.updateState(key, conn, flags, nftypes.Ingress, size)
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true return true
} }
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
state := conn.State currentState := conn.GetState()
defer func() {
if state != conn.State { if flags&TCPRst != 0 {
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State) if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
return
} }
}()
switch state { var newState TCPState
switch currentState {
case TCPStateNew: case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
} }
case TCPStateSynSent: case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound { if packetDir != conn.Direction {
conn.State = TCPStateEstablished newState = TCPStateEstablished
conn.SetEstablished(true)
} else { } else {
// Simultaneous open // Simultaneous open
conn.State = TCPStateSynReceived newState = TCPStateSynReceived
} }
} }
case TCPStateSynReceived: case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 { if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished if packetDir == conn.Direction {
conn.SetEstablished(true) newState = TCPStateEstablished
}
} }
case TCPStateEstablished: case TCPStateEstablished:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
if isOutbound { if packetDir == conn.Direction {
conn.State = TCPStateFinWait1 newState = TCPStateFinWait1
} else { } else {
conn.State = TCPStateCloseWait newState = TCPStateCloseWait
} }
conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
if packetDir != conn.Direction {
switch { switch {
case flags&TCPFin != 0 && flags&TCPAck != 0: case flags&TCPFin != 0 && flags&TCPAck != 0:
conn.State = TCPStateClosing newState = TCPStateClosing
case flags&TCPFin != 0: case flags&TCPFin != 0:
conn.State = TCPStateFinWait2 newState = TCPStateClosing
case flags&TCPAck != 0: case flags&TCPAck != 0:
conn.State = TCPStateFinWait2 newState = TCPStateFinWait2
case flags&TCPRst != 0: }
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait newState = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait newState = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateLastAck newState = TCPStateLastAck
} }
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed newState = TCPStateClosed
conn.SetTombstone() }
}
// Send close event for gracefully closed connections if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
switch newState {
case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
} }
} }
} }
@@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) { if !isValidFlagCombination(flags) {
return false return false
} }
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
}
return true
}
switch state { switch state {
case TCPStateNew: case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0 return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent: case TCPStateSynSent:
// TODO: support simultaneous open
return flags&TCPSyn != 0 && flags&TCPAck != 0 return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived: case TCPStateSynReceived:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateEstablished: case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateFinWait1: case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0 return flags&TCPFin != 0 || flags&TCPAck != 0
@@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
case TCPStateLastAck: case TCPStateLastAck:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateClosed: case TCPStateClosed:
// Accept retransmitted ACKs in closed state // Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0 return flags&TCPAck != 0
} }
return false return false
@@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() {
} }
var timeout time.Duration var timeout time.Duration
switch { currentState := conn.GetState()
case conn.State == TCPStateTimeWait: switch currentState {
timeout = TimeWaitTimeout case TCPStateTimeWait:
case conn.IsEstablished(): timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout timeout = t.timeout
default: default:
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
if conn.timeoutExceeded(timeout) { if conn.timeoutExceeded(timeout) {
// Return IPs to pool
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", key) t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change // event already handled by state change
if conn.State != TCPStateTimeWait { if currentState != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }

View File

@@ -0,0 +1,83 @@
package conntrack
import (
"net/netip"
"testing"
"time"
)
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -5,6 +5,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) {
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
}, },
}, },
{ {
@@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) {
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State) require.Equal(t, TCPStateClosed, conn.GetState())
require.False(t, conn.IsEstablished())
} }
}) })
} }
} }
func TestTCPRetransmissions(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test SYN retransmission
t.Run("SYN Retransmission", func(t *testing.T) {
// Initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Retransmit SYN (should not affect the state machine)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Verify we're still in SYN-SENT state
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateSynSent, conn.GetState())
// Complete the handshake
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Verify we're in ESTABLISHED state
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test ACK retransmission in established state
t.Run("ACK Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Retransmit ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test FIN retransmission
t.Run("FIN Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Retransmit FIN (should not change state)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
})
}
func TestTCPDataTransfer(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Data Transfer", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Receive ACK for data
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
require.True(t, valid)
// Receive data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
require.True(t, valid)
// Send ACK for received data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
})
}
func TestTCPHalfClosedConnections(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test half-closed connection: local end closes, remote end continues sending data
t.Run("Local Close, Remote Data", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Remote end can still send data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
require.True(t, valid)
// We can still ACK their data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Receive FIN from remote end
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain TIME-WAIT (waiting for possible retransmissions)
require.Equal(t, TCPStateTimeWait, conn.GetState())
})
// Test half-closed connection: remote end closes, local end continues sending data
t.Run("Remote Close, Local Data", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Receive FIN from remote
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// We can still send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Remote can still ACK our data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Send our FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Receive final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
})
}
func TestTCPAbnormalSequences(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test handling of unsolicited RST in various states
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
// Send SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive unsolicited RST (without proper ACK)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
// Receive RST with proper ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.Equal(t, TCPStateClosed, conn.GetState())
require.True(t, conn.IsTombstone())
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Connection Timeout", func(t *testing.T) {
// Establish a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Wait for the connection to timeout
time.Sleep(2 * shortTimeout)
// Force cleanup
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after timeout")
})
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Complete the connection close to enter TIME_WAIT
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// TIME_WAIT should have its own timeout value (usually 2*MSL)
// For the test, we're using a short timeout
time.Sleep(2 * shortTimeout)
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
})
}
func TestSynFlood(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
basePort := uint16(10000)
dstPort := uint16(80)
// Create a large number of SYN packets to simulate a SYN flood
for i := uint16(0); i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
}
// Check that we're tracking all connections
require.Equal(t, 1000, len(tracker.connections))
// Now simulate SYN timeout
var oldConns int
tracker.mutex.Lock()
for _, conn := range tracker.connections {
if conn.GetState() == TCPStateSynSent {
// Make the connection appear old
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
oldConns++
}
}
tracker.mutex.Unlock()
require.Equal(t, 1000, oldConns)
// Run cleanup
tracker.cleanup()
// Check that stale connections were cleaned up
require.Equal(t, 0, len(tracker.connections))
}
func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := netip.MustParseAddr("100.64.0.1")
serverIP := netip.MustParseAddr("100.64.0.2")
clientPort := uint16(12345)
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{
SrcIP: clientIP,
DstIP: serverIP,
SrcPort: clientPort,
DstPort: serverPort,
}
tracker.mutex.RLock()
conn := tracker.connections[key]
tracker.mutex.RUnlock()
require.NotNil(t, conn)
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
// 2. Server sends SYN-ACK response
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
// Server sends data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
}
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
} }

View File

@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }

View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -17,6 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
@@ -31,12 +34,14 @@ const (
type Forwarder struct { type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
ip net.IP ip tcpip.Address
netstack bool netstack bool
} }
@@ -66,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("failed to create NIC: %v", err)
} }
ones, _ := iface.Address().Network.Mask.Size()
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber, Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddressWithPrefix{ AddressWithPrefix: tcpip.AddressWithPrefix{
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
PrefixLen: ones, PrefixLen: iface.Address().Network.Bits(),
}, },
} }
@@ -111,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
ip: iface.Address().IP, ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
} }
receiveWindow := defaultReceiveWindow receiveWindow := defaultReceiveWindow
@@ -162,8 +166,39 @@ func (f *Forwarder) Stop() {
} }
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr.AsSlice()) { if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1) return net.IPv4(127, 0, 0, 1)
} }
return addr.AsSlice() return addr.AsSlice()
} }
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
key := buildKey(srcIP, dstIP, srcPort, dstPort)
f.ruleIdMap.LoadOrStore(key, ruleID)
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
return value.([]byte), true
}
return nil, false
}
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return
}
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
}
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
return conntrack.ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
} }
flowID := uuid.New() flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel() defer cancel()
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root // TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil { if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
} }
defer func() { defer func() {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err) f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
} }
}() }()
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
if _, err = conn.WriteTo(payload, dst); err != nil { if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true return true
} }
f.logger.Trace("Forwarded ICMP packet %v type %v code %v", f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response // For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo { if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id) rxBytes := pkt.Size()
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode) txBytes := f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
} }
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true return true
} }
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) { func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err) f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
return return 0
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response) n, _, err := conn.ReadFrom(response)
if err != nil { if err != nil {
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err) f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
} }
return return 0
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
fullPacket = append(fullPacket, response[:n]...) fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err) f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
return return 0
} }
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
} }
// sendICMPEvent stores flow events for ICMP packets // sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) { func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
f.flowLogger.StoreEvent(nftypes.EventFields{ var rxPackets, txPackets uint64
if rxBytes > 0 {
rxPackets = 1
}
if txBytes > 0 {
txPackets = 1
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.ICMP, Protocol: nftypes.ICMP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), SourceIP: srcIp,
DestIP: netip.AddrFrom4(id.LocalAddress.As4()), DestIP: dstIp,
ICMPType: icmpType, ICMPType: icmpType,
ICMPCode: icmpCode, ICMPCode: icmpCode,
// TODO: get packets/bytes RxBytes: rxBytes,
}) TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
} }

View File

@@ -6,8 +6,10 @@ import (
"io" "io"
"net" "net"
"net/netip" "net/netip"
"sync"
"github.com/google/uuid" "github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
flowID := uuid.New() flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil) f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool var success bool
defer func() { defer func() {
if !success { if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
} }
}() }()
@@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
} }
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx) ctx, cancel := context.WithCancel(f.ctx)
defer cancel() defer cancel()
errChan := make(chan error, 2) go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() { go func() {
_, err := io.Copy(outConn, inConn) bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
errChan <- err cancel()
wg.Done()
}() }()
go func() { go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}() }()
select { wg.Wait()
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id)) if errInToOut != nil {
return if !isClosedError(errInToOut) {
case err := <-errChan: f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id)) }
return if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
} }
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), SourceIP: srcIp,
DestIP: netip.AddrFrom4(id.LocalAddress.As4()), DestIP: dstIp,
SourcePort: id.RemotePort, SourcePort: id.RemotePort,
DestPort: id.LocalPort, DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
} }
if ep != nil { if typ == nftypes.TypeStart {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
// fields are flipped since this is the in conn fields.RuleID = ruleId
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
} }
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
} }
f.flowLogger.StoreEvent(fields) f.flowLogger.StoreEvent(fields)

View File

@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
flowID := uuid.New() flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil) f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool var success bool
defer func() { defer func() {
if !success { if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil) f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
} }
}() }()
@@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
@@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
} }
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
pConn.cancel() pConn.cancel()
if err := pConn.conn.Close(); err != nil { if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := pConn.outConn.Close(); err != nil { if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
ep.Close() ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var txBytes, rxBytes int64
var outboundErr, inboundErr error
// outbound->inbound: copy from pConn.conn to pConn.outConn
go func() {
defer wg.Done()
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
// inbound->outbound: copy from pConn.outConn to pConn.conn
go func() {
defer wg.Done()
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
}
var rxPackets, txPackets uint64
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
rxPackets = udpStats.PacketsSent.Value()
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock() f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep) f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
}()
errChan := make(chan error, 2)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return
}
} }
// sendUDPEvent stores flow events for UDP connections // sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) { func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()), SourceIP: srcIp,
DestIP: netip.AddrFrom4(id.LocalAddress.As4()), DestIP: dstIp,
SourcePort: id.RemotePort, SourcePort: id.RemotePort,
DestPort: id.LocalPort, DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
} }
if ep != nil { if typ == nftypes.TypeStart {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok { if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
// fields are flipped since this is the in conn fields.RuleID = ruleId
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
} }
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
} }
f.flowLogger.StoreEvent(fields) f.flowLogger.StoreEvent(fields)
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
return time.Since(lastSeen) return time.Since(lastSeen)
} }
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { // copy reads from src and writes to dst.
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
bufp := bufPool.Get().(*[]byte) bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp) defer bufPool.Put(bufp)
buffer := *bufp buffer := *bufp
var totalBytes int64 = 0
for { for {
if ctx.Err() != nil { if ctx.Err() != nil {
return ctx.Err() return totalBytes, ctx.Err()
} }
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err) return totalBytes, fmt.Errorf("set read deadline: %w", err)
} }
n, err := src.Read(buffer) n, err := src.Read(buffer)
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
if isTimeout(err) { if isTimeout(err) {
continue continue
} }
return fmt.Errorf("read from %s: %w", direction, err) return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
} }
_, err = dst.Write(buffer[:n]) nWritten, err := dst.Write(buffer[:n])
if err != nil { if err != nil {
return fmt.Errorf("write to %s: %w", direction, err) return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
} }
totalBytes += int64(nWritten)
c.updateLastSeen() c.updateLastSeen()
} }
} }

View File

@@ -14,8 +14,13 @@ import (
type localIPManager struct { type localIPManager struct {
mu sync.RWMutex mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory) // fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [1 << 16]uint32 ipv4Bitmap [256]*ipv4LowBitmap
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
} }
func newLocalIPManager() *localIPManager { func newLocalIPManager() *localIPManager {
@@ -27,35 +32,61 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
if ipv4 == nil { if ipv4 == nil {
return return
} }
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) high := uint16(ipv4[0])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
}
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
}
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
if !ip.Is4() {
return
}
ipv4 := ip.AsSlice()
high := uint16(ipv4[0])
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
}
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
if _, exists := ipv4Set[ip]; !exists {
ipv4Set[ip] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ip)
}
} }
func (m *localIPManager) checkBitmapBit(ip []byte) bool { func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := (uint16(ip[0]) << 8) | uint16(ip[1]) high := uint16(ip[0])
low := (uint16(ip[2]) << 8) | uint16(ip[3]) low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
if m.ipv4Bitmap[high] == nil {
return false
} }
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { index := low / 32
if ipv4 := ip.To4(); ipv4 != nil { bit := low % 32
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip)
}
ipStr := ip.String()
if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
}
} }
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -73,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
continue continue
} }
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil { addr, ok := netip.AddrFromSlice(ip)
if !ok {
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
continue
}
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -86,14 +123,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
}() }()
var newIPv4Bitmap [1 << 16]uint32 var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{}) ipv4Set := make(map[netip.Addr]struct{})
var ipv4Addresses []string var ipv4Addresses []netip.Addr
// 127.0.0.0/8 // 127.0.0.0/8
high := uint16(127) << 8 newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := uint16(0); i < 256; i++ { for i := 0; i < 8192; i++ {
newIPv4Bitmap[high|i] = 0xffffffff newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
} }
if iface != nil { if iface != nil {
@@ -120,12 +157,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
if ip.Is4() {
return m.checkBitmapBit(ip.AsSlice()) return m.checkBitmapBit(ip.AsSlice())
} }
return false
}

View File

@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
@@ -68,23 +56,26 @@ func TestLocalIPManager(t *testing.T) {
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: netip.MustParseAddr("192.168.1.1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
}, },
testIP: netip.MustParseAddr("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
}, },
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: netip.MustParseAddr("192.168.1.1"),
Network: netip.MustParsePrefix("192.168.1.0/24"),
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false,
},
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: wgaddr.Address{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: netip.MustParseAddr("fe80::1"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("192.168.1.0/24"),
IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128),
},
}, },
testIP: netip.MustParseAddr("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
@@ -192,10 +183,8 @@ func BenchmarkIPChecks(b *testing.B) {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
} }
// Setup bitmap version // Setup bitmap
bitmapManager := &localIPManager{ bitmapManager := newLocalIPManager()
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip) bitmapManager.setBitmapBit(ip)
} }
@@ -248,7 +237,7 @@ func BenchmarkWGPosition(b *testing.B) {
// Create two managers - one checks WG IP first, other checks it last // Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) { b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} bm := newLocalIPManager()
bm.setBitmapBit(wgIP) bm.setBitmapBit(wgIP)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@@ -257,7 +246,7 @@ func BenchmarkWGPosition(b *testing.B) {
}) })
b.Run("WG_Last", func(b *testing.B) { b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} bm := newLocalIPManager()
// Fill with other IPs first // Fill with other IPs first
for i := 0; i < 15; i++ { for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))

View File

@@ -32,7 +32,8 @@ type RouteRule struct {
id string id string
mgmtId []byte mgmtId []byte
sources []netip.Prefix sources []netip.Prefix
destination netip.Prefix dstSet firewall.Set
destinations []netip.Prefix
proto firewall.Protocol proto firewall.Protocol
srcPort *firewall.Port srcPort *firewall.Port
dstPort *firewall.Port dstPort *firewall.Port

View File

@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -198,12 +195,12 @@ func TestTracePacket(t *testing.T) {
m.forwarder.Store(&forwarder.Forwarder{}) m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err) require.NoError(t, err)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -222,12 +219,12 @@ func TestTracePacket(t *testing.T) {
m.nativeRouter.Store(false) m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32) src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32) dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop) _, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err) require.NoError(t, err)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -245,7 +242,7 @@ func TestTracePacket(t *testing.T) {
m.nativeRouter.Store(true) m.nativeRouter.Store(true)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -263,7 +260,7 @@ func TestTracePacket(t *testing.T) {
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
}, },
packetBuilder: func() *PacketBuilder { packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN) return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
}, },
expectedStages: []PacketStage{ expectedStages: []PacketStage{
StageReceived, StageReceived,
@@ -425,8 +422,8 @@ func TestTracePacket(t *testing.T) {
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")), require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP") "100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")), require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
"172.17.0.2 should not be recognized as a local IP") "192.168.17.2 should not be recognized as a local IP")
pb := tc.packetBuilder() pb := tc.packetBuilder()

View File

@@ -39,8 +39,12 @@ const (
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack // EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible // Default off as it might be security risk because sockets listening on localhost only will become accessible.
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
@@ -49,10 +53,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule type RuleSet map[string]PeerRule
type RouteRules []RouteRule type RouteRules []*RouteRule
func (r RouteRules) Sort() { func (r RouteRules) Sort() {
slices.SortStableFunc(r, func(a, b RouteRule) int { slices.SortStableFunc(r, func(a, b *RouteRule) int {
// Deny rules come first // Deny rules come first
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
return -1 return -1
@@ -71,7 +75,6 @@ type Manager struct {
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[netip.Addr]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
wgIface common.IFaceMapper wgIface common.IFaceMapper
nativeFirewall firewall.Manager nativeFirewall firewall.Manager
@@ -99,6 +102,8 @@ type Manager struct {
forwarder atomic.Pointer[forwarder.Forwarder] forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
blockRule firewall.Rule
} }
// decoder for packages // decoder for packages
@@ -146,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
} }
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
enableLocalForwarding, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
} }
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
@@ -201,41 +211,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
} }
if err := m.blockInvalidRouted(iface); err != nil {
log.Errorf("failed to block invalid routed traffic: %v", err)
}
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err) return nil, fmt.Errorf("set filter: %w", err)
} }
return m, nil return m, nil
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
if m.forwarder.Load() == nil {
return nil
}
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
if err != nil { if err != nil {
return fmt.Errorf("parse wireguard network: %w", err) return nil, fmt.Errorf("parse wireguard network: %w", err)
} }
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering( rule, err := m.addRouteFiltering(
nil, nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix, firewall.Network{Prefix: wgPrefix},
firewall.ProtocolALL, firewall.ProtocolALL,
nil, nil,
nil, nil,
firewall.ActionDrop, firewall.ActionDrop,
); err != nil { )
return fmt.Errorf("block wg nte : %w", err) if err != nil {
return nil, fmt.Errorf("block wg nte : %w", err)
} }
// TODO: Block networks that we're a client of // TODO: Block networks that we're a client of
return nil return rule, nil
} }
func (m *Manager) determineRouting() error { func (m *Manager) determineRouting() error {
@@ -273,7 +277,7 @@ func (m *Manager) determineRouting() error {
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported(): case !m.netstack && m.nativeFirewall != nil:
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
@@ -330,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) IsStateful() bool {
return m.stateful
}
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
@@ -413,10 +421,23 @@ func (m *Manager) AddPeerFiltering(
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte, id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port, action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) addRouteFiltering(
id []byte,
sources []netip.Prefix,
destination firewall.Network,
proto firewall.Protocol,
sPort, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
@@ -429,31 +450,36 @@ func (m *Manager) AddRouteFiltering(
id: ruleID, id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
destination: destination, dstSet: destination.Set,
proto: proto, proto: proto,
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
} }
if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix}
}
m.mutex.Lock() m.routeRules = append(m.routeRules, &rule)
m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.deleteRouteRule(rule)
}
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := rule.ID() ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
return r.id == ruleID return r.id == ruleID
}) })
if idx < 0 { if idx < 0 {
@@ -509,6 +535,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.nativeFirewall.DeleteDNATRule(rule) return m.nativeFirewall.DeleteDNATRule(rule)
} }
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.UpdateSet(set, prefixes)
}
m.mutex.Lock()
defer m.mutex.Unlock()
var matches []*RouteRule
for _, rule := range m.routeRules {
if rule.dstSet == set {
matches = append(matches, rule)
}
}
if len(matches) == 0 {
return fmt.Errorf("no route rule found for set: %s", set)
}
destinations := matches[0].destinations
for _, prefix := range prefixes {
if prefix.Addr().Is4() {
destinations = append(destinations, prefix)
}
}
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
cmp := a.Addr().Compare(b.Addr())
if cmp != 0 {
return cmp
}
return a.Bits() - b.Bits()
})
destinations = slices.Compact(destinations)
for _, rule := range matches {
rule.destinations = destinations
}
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
return nil
}
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte, size int) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData, size) return m.processOutgoingHooks(packetData, size)
@@ -546,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
return true return true
} }
if m.stateful { // for netflow we keep track even if the firewall is stateless
m.trackOutbound(d, srcIP, dstIP, size) m.trackOutbound(d, srcIP, dstIP, size)
}
return false return false
} }
@@ -658,7 +729,8 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
if !m.isValidPacket(d, packetData) { valid, fragment := m.isValidPacket(d, packetData)
if !valid {
return true return true
} }
@@ -668,6 +740,13 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return true return true
} }
// TODO: pass fragments of routed packets to forwarder
if fragment {
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
return false
}
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked. // This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
@@ -678,7 +757,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size) return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
} }
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
} }
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
@@ -709,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return true return true
} }
// if running in netstack mode we need to pass this to the forwarder // If requested we pass local traffic to internal interfaces to the forwarder.
if m.netstack && m.localForwarding { // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
return m.handleNetstackLocalTraffic(packetData) if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
return m.handleForwardedLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
@@ -721,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
return false return false
} }
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool { func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
fwd := m.forwarder.Load() fwd := m.forwarder.Load()
if fwd == nil { if fwd == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
@@ -739,7 +818,7 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic. // handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled.Load() { if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
@@ -749,13 +828,15 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
// Pass to native stack if native router is enabled or forced // Pass to native stack if native router is enabled or forced
if m.nativeRouter.Load() { if m.nativeRouter.Load() {
m.trackInbound(d, srcIP, dstIP, nil, size)
return false return false
} }
proto, pnum := getProtocolFromPacket(d) proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
if !pass {
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
@@ -770,6 +851,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
}) })
return true return true
} }
@@ -779,8 +862,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if fwd == nil { if fwd == nil {
m.logger.Trace("failed to forward routed packet (forwarder not initialized)") m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
} else { } else {
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
if err := fwd.InjectIncomingPacket(packetData); err != nil { if err := fwd.InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject routed packet: %v", err) m.logger.Error("Failed to inject routed packet: %v", err)
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
} }
} }
@@ -812,17 +898,32 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
} }
} }
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { // isValidPacket checks if the packet is valid.
// It returns true, false if the packet is valid and not a fragment.
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
m.logger.Trace("couldn't decode packet, err: %s", err) m.logger.Trace("couldn't decode packet, err: %s", err)
return false return false, false
} }
if len(d.decoded) < 2 { l := len(d.decoded)
m.logger.Trace("packet doesn't have network and transport layers")
return false // L3 and L4 are mandatory
if l >= 2 {
return true, false
} }
return true
// Fragments are also valid
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
ip4 := d.ip4
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
return true, true
}
}
m.logger.Trace("packet doesn't have network and transport layers")
return false, false
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
@@ -962,8 +1063,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
return nil, false return nil, false
} }
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) { destMatched := false
for _, dst := range rule.destinations {
if dst.Contains(dstAddr) {
destMatched = true
break
}
}
if !destMatched {
return false return false
} }
@@ -991,11 +1099,6 @@ func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto
return true return true
} }
// SetNetwork of the wireguard interface to which filtering applied
func (m *Manager) SetNetwork(network *net.IPNet) {
m.wgNetwork = network
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not
@@ -1065,7 +1168,22 @@ func (m *Manager) EnableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.determineRouting() if err := m.determineRouting(); err != nil {
return fmt.Errorf("determine routing: %w", err)
}
if m.forwarder.Load() == nil {
return nil
}
rule, err := m.blockInvalidRouted(m.wgIface)
if err != nil {
return fmt.Errorf("block invalid routed: %w", err)
}
m.blockRule = rule
return nil
} }
func (m *Manager) DisableRouting() error { func (m *Manager) DisableRouting() error {
@@ -1090,5 +1208,12 @@ func (m *Manager) DisableRouting() error {
log.Debug("forwarder stopped") log.Debug("forwarder stopped")
if m.blockRule != nil {
if err := m.deleteRouteRule(m.blockRule); err != nil {
return fmt.Errorf("delete block rule: %w", err)
}
m.blockRule = nil
}
return nil return nil
} }

View File

@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Apply scenario-specific setup // Apply scenario-specific setup
sc.setupFunc(manager) sc.setupFunc(manager)
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
// Pre-populate connection table // Pre-populate connection table
srcIPs := generateRandomIPs(count) srcIPs := generateRandomIPs(count)
dstIPs := generateRandomIPs(count) dstIPs := generateRandomIPs(count)
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
srcIP := generateRandomIPs(1)[0] srcIP := generateRandomIPs(1)[0]
dstIP := generateRandomIPs(1)[0] dstIP := generateRandomIPs(1)[0]
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
b.Setenv("NB_DISABLE_CONNTRACK", "1") b.Setenv("NB_DISABLE_CONNTRACK", "1")
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolTCP, proto: layers.IPProtocolTCP,
state: "post_handshake", state: "post_handshake",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "new", state: "new",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
proto: layers.IPProtocolUDP, proto: layers.IPProtocolUDP,
state: "established", state: "established",
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("0.0.0.0"),
Mask: net.CIDRMask(0, 32),
}
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
}, },
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
})
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "") _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
require.NoError(b, err) require.NoError(b, err)
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept) dst := fw.Network{Prefix: r.dest}
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@@ -15,15 +15,12 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
localIP := net.ParseIP("100.10.0.100") localIP := netip.MustParseAddr("100.10.0.100")
wgNet := &net.IPNet{ wgNet := netip.MustParsePrefix("100.10.0.0/16")
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
@@ -42,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet
err = manager.UpdateLocalIPs() err = manager.UpdateLocalIPs()
require.NoError(t, err) require.NoError(t, err)
@@ -188,6 +183,281 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionAccept, ruleAction: fw.ActionAccept,
shouldBeBlocked: true, shouldBeBlocked: true,
}, },
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Block TCP traffic outside port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Edge Case - Port at Range Boundary",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8100,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "UDP Port Range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 5060,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
// New drop test cases
{
name: "Drop TCP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop UDP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop ICMP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop all traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop traffic from multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: false,
},
{
name: "Drop TCP traffic with source port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 32100,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Mixed rule - drop specific port but allow other ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
} }
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
@@ -198,6 +468,28 @@ func TestPeerACLFiltering(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil, nil,
net.ParseIP(tc.ruleIP), net.ParseIP(tc.ruleIP),
@@ -283,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
dev := mocks.NewMockDevice(ctrl) dev := mocks.NewMockDevice(ctrl)
dev.EXPECT().MTU().Return(1500, nil).AnyTimes() dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
localIP, wgNet, err := net.ParseCIDR(network) wgNet := netip.MustParsePrefix(network)
require.NoError(tb, err)
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: localIP, IP: wgNet.Addr(),
Network: wgNet, Network: wgNet,
} }
}, },
@@ -303,8 +594,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled.Load()) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter.Load()) require.False(tb, manager.nativeRouter.Load())
@@ -321,7 +612,7 @@ func TestRouteACLFiltering(t *testing.T) {
type rule struct { type rule struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -347,7 +638,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -363,7 +654,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -379,7 +670,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -395,7 +686,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53, dstPort: 53,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{53}}, dstPort: &fw.Port{Values: []uint16{53}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -409,7 +700,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@@ -424,7 +715,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -440,7 +731,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -456,7 +747,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -472,7 +763,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -488,7 +779,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -507,7 +798,7 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -521,7 +812,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@@ -536,33 +827,13 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
shouldPass: true, shouldPass: true,
}, },
{
name: "Multiple source networks with mismatched protocol",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
// Should not match TCP rule
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{ {
name: "Allow multiple destination ports", name: "Allow multiple destination ports",
srcIP: "100.10.0.1", srcIP: "100.10.0.1",
@@ -572,7 +843,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -588,7 +859,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -604,7 +875,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
@@ -621,7 +892,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -640,7 +911,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999, dstPort: 7999,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -659,7 +930,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -678,7 +949,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -700,7 +971,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8100, dstPort: 8100,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -719,7 +990,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 5060, dstPort: 5060,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -738,7 +1009,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -757,7 +1028,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -773,7 +1044,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
@@ -791,17 +1062,158 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
shouldPass: false, shouldPass: false,
}, },
{
name: "Drop empty destination set",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Set: fw.Set{}},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
action: fw.ActionDrop,
},
shouldPass: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: false,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil, nil,
tc.rule.sources, tc.rule.sources,
@@ -836,7 +1248,7 @@ func TestRouteACLOrder(t *testing.T) {
name string name string
rules []struct { rules []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -857,7 +1269,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Drop rules take precedence over accept", name: "Drop rules take precedence over accept",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -866,7 +1278,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept rule added first // Accept rule added first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 443}}, dstPort: &fw.Port{Values: []uint16{80, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -874,7 +1286,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop rule added second but should be evaluated first // Drop rule added second but should be evaluated first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -912,7 +1324,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Multiple drop rules take precedence", name: "Multiple drop rules take precedence",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -921,14 +1333,14 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept all // Accept all
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
{ {
// Drop specific port // Drop specific port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -936,7 +1348,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop different port // Drop different port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -1015,3 +1427,50 @@ func TestRouteACLOrder(t *testing.T) {
}) })
} }
} }
func TestRouteACLSet(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/management/domain"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
@@ -270,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address { AddressFunc: func() wgaddr.Address {
return wgaddr.Address{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: netip.MustParseAddr("100.10.0.100"),
Network: &net.IPNet{ Network: netip.MustParsePrefix("100.10.0.0/16"),
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
} }
}, },
} }
@@ -284,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
m.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
@@ -395,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
@@ -508,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, false, flowLogger) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
}
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
@@ -711,3 +696,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}) })
} }
} }
func TestUpdateSetMerge(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
// Update the set with initial prefixes
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Test initial prefixes work
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP1 := netip.MustParseAddr("10.0.0.100")
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
newPrefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.1.0.0/24"),
}
err = manager.UpdateSet(set, newPrefixes)
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
// Check that new prefixes are included
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
// Verify the rule has all prefixes
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
}
func TestUpdateSetDeduplication(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
}
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Check the internal state for deduplication
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
// Check the prefixes are correct
expectedPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
for i, prefix := range expectedPrefixes {
require.True(t, r.destinations[i] == prefix,
"Prefix should match expected value")
}
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
// Test with overlapping prefixes of different sizes
overlappingPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/16"), // More general
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
netip.MustParsePrefix("192.168.0.0/16"), // More general
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
}
err = manager.UpdateSet(set, overlappingPrefixes)
require.NoError(t, err)
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")
// Verify they're sorted correctly (more specific prefixes should come first)
prefixes := make([]string, 0, len(r.destinations))
for _, p := range r.destinations {
prefixes = append(prefixes, p.String())
}
// Check sorted order
require.Equal(t, []string{
"10.0.0.0/16",
"10.0.0.0/24",
"192.168.0.0/16",
"192.168.1.0/24",
}, prefixes, "Prefixes should be sorted")
}
}
manager.mutex.RUnlock()
// Test functionality with all prefixes
testCases := []struct {
dstIP netip.Addr
expected bool
desc string
}{
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
}
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
// NewUDPMuxDefault creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = getLogger()
} }
mux := &UDPMuxDefault{ mux := &UDPMuxDefault{
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
buf: make([]byte, size), buf: make([]byte, size),
} }
} }
func getLogger() logging.LeveledLogger {
fac := logging.NewDefaultLoggerFactory()
//fac.Writer = log.StandardLogger().Writer()
return fac.NewLogger("ice")
}

View File

@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = getLogger()
} }
if params.XORMappedAddrCacheTTL == 0 { if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25 params.XORMappedAddrCacheTTL = time.Second * 25
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }

View File

@@ -12,6 +12,8 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var zeroKey wgtypes.Key
type KernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
@@ -201,14 +203,71 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
func (c *KernelConfigurer) Close() { func (c *KernelConfigurer) Close() {
} }
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) { func (c *KernelConfigurer) FullStats() (*Stats, error) {
peer, err := c.getPeer(c.deviceName, peerKey) wg, err := wgctrl.New()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) return nil, fmt.Errorf("wgctl: %w", err)
} }
return WGStats{ defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
fullStats := &Stats{
DeviceName: wgDevice.Name,
PublicKey: wgDevice.PublicKey.String(),
ListenPort: wgDevice.ListenPort,
FWMark: wgDevice.FirewallMark,
Peers: []Peer{},
}
for _, p := range wgDevice.Peers {
peer := Peer{
PublicKey: p.PublicKey.String(),
AllowedIPs: p.AllowedIPs,
TxBytes: p.TransmitBytes,
RxBytes: p.ReceiveBytes,
LastHandshake: p.LastHandshakeTime,
PresharedKey: p.PresharedKey != zeroKey,
}
if p.Endpoint != nil {
peer.Endpoint = *p.Endpoint
}
fullStats.Peers = append(fullStats.Peers, peer)
}
return fullStats, nil
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
return nil, fmt.Errorf("wgctl: %w", err)
}
defer func() {
err = wg.Close()
if err != nil {
log.Errorf("Got error while closing wgctl: %v", err)
}
}()
wgDevice, err := wg.Device(c.deviceName)
if err != nil {
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
}
for _, peer := range wgDevice.Peers {
stats[peer.PublicKey.String()] = WGStats{
LastHandshake: peer.LastHandshakeTime, LastHandshake: peer.LastHandshakeTime,
TxBytes: peer.TransmitBytes, TxBytes: peer.TransmitBytes,
RxBytes: peer.ReceiveBytes, RxBytes: peer.ReceiveBytes,
}, nil }
}
return stats, nil
} }

View File

@@ -1,6 +1,7 @@
package configurer package configurer
import ( import (
"encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net" "net"
@@ -17,6 +18,20 @@ import (
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const (
privateKey = "private_key"
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
ipcKeyTxBytes = "tx_bytes"
ipcKeyRxBytes = "rx_bytes"
allowedIP = "allowed_ip"
endpoint = "endpoint"
fwmark = "fwmark"
listenPort = "listen_port"
publicKey = "public_key"
presharedKey = "preshared_key"
)
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type WGUSPConfigurer struct { type WGUSPConfigurer struct {
@@ -178,6 +193,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
ipcStr, err := c.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("IpcGet failed: %w", err)
}
return parseStatus(c.deviceName, ipcStr)
}
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *WGUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
@@ -217,91 +241,75 @@ func (t *WGUSPConfigurer) Close() {
} }
} }
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) { func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet() ipc, err := t.device.IpcGet()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err) return nil, fmt.Errorf("ipc get: %w", err)
} }
stats, err := findPeerInfo(ipc, peerKey, []string{ return parseTransfers(ipc)
"last_handshake_time_sec",
"last_handshake_time_nsec",
"tx_bytes",
"rx_bytes",
})
if err != nil {
return WGStats{}, fmt.Errorf("find peer info: %w", err)
} }
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64) func parseTransfers(ipc string) (map[string]WGStats, error) {
if err != nil { stats := make(map[string]WGStats)
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err) var (
} currentKey string
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64) currentStats WGStats
if err != nil { hasPeer bool
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err) )
} lines := strings.Split(ipc, "\n")
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
}
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
if err != nil {
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
}
return WGStats{
LastHandshake: time.Unix(sec, nsec),
TxBytes: txBytes,
RxBytes: rxBytes,
}, nil
}
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil {
return nil, fmt.Errorf("parse key: %w", err)
}
hexKey := hex.EncodeToString(peerKeyParsed[:])
lines := strings.Split(ipcInput, "\n")
configFound := map[string]string{}
foundPeer := false
for _, line := range lines { for _, line := range lines {
line = strings.TrimSpace(line) line = strings.TrimSpace(line)
// If we're within the details of the found peer and encounter another public key, // If we're within the details of the found peer and encounter another public key,
// this means we're starting another peer's details. So, stop. // this means we're starting another peer's details. So, stop.
if strings.HasPrefix(line, "public_key=") && foundPeer { if strings.HasPrefix(line, "public_key=") {
break peerID := strings.TrimPrefix(line, "public_key=")
h, err := hex.DecodeString(peerID)
if err != nil {
return nil, fmt.Errorf("decode peerID: %w", err)
}
currentKey = base64.StdEncoding.EncodeToString(h)
currentStats = WGStats{} // Reset stats for the new peer
hasPeer = true
stats[currentKey] = currentStats
continue
} }
// Identify the peer with the specific public key if !hasPeer {
if line == fmt.Sprintf("public_key=%s", hexKey) { continue
foundPeer = true
} }
for _, key := range searchConfigKeys { key := strings.SplitN(line, "=", 2)
if foundPeer && strings.HasPrefix(line, key+"=") { if len(key) != 2 {
v := strings.SplitN(line, "=", 2) continue
configFound[v[0]] = v[1]
} }
switch key[0] {
case ipcKeyLastHandshakeTimeSec:
hs, err := toLastHandshake(key[1])
if err != nil {
return nil, err
}
currentStats.LastHandshake = hs
stats[currentKey] = currentStats
case ipcKeyRxBytes:
rxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse rx_bytes: %w", err)
}
currentStats.RxBytes = rxBytes
stats[currentKey] = currentStats
case ipcKeyTxBytes:
TxBytes, err := toBytes(key[1])
if err != nil {
return nil, fmt.Errorf("parse tx_bytes: %w", err)
}
currentStats.TxBytes = TxBytes
stats[currentKey] = currentStats
} }
} }
// todo: use multierr return stats, nil
for _, key := range searchConfigKeys {
if _, ok := configFound[key]; !ok {
return configFound, fmt.Errorf("config key not found: %s", key)
}
}
if !foundPeer {
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
}
return configFound, nil
} }
func toWgUserspaceString(wgCfg wgtypes.Config) string { func toWgUserspaceString(wgCfg wgtypes.Config) string {
@@ -355,9 +363,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
return sb.String() return sb.String()
} }
func toLastHandshake(stringVar string) (time.Time, error) {
sec, err := strconv.ParseInt(stringVar, 10, 64)
if err != nil {
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
}
return time.Unix(sec, 0), nil
}
func toBytes(s string) (int64, error) {
return strconv.ParseInt(s, 10, 64)
}
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark return nbnet.ControlPlaneMark
} }
return 0 return 0
} }
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
// Decode hex string to bytes
keyBytes, err := hex.DecodeString(hexKey)
if err != nil {
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
}
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
if len(keyBytes) != 32 {
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
}
// Convert to wgtypes.Key
var key wgtypes.Key
copy(key[:], keyBytes)
return key, nil
}
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
stats := &Stats{DeviceName: deviceName}
var currentPeer *Peer
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
if line == "" {
continue
}
parts := strings.SplitN(line, "=", 2)
if len(parts) != 2 {
continue
}
key := parts[0]
val := parts[1]
switch key {
case privateKey:
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse private key: %v", err)
continue
}
stats.PublicKey = key.PublicKey().String()
case publicKey:
// Save previous peer
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
key, err := hexToWireguardKey(val)
if err != nil {
log.Errorf("failed to parse public key: %v", err)
continue
}
currentPeer = &Peer{
PublicKey: key.String(),
}
case listenPort:
if port, err := strconv.Atoi(val); err == nil {
stats.ListenPort = port
}
case fwmark:
if fwmark, err := strconv.Atoi(val); err == nil {
stats.FWMark = fwmark
}
case endpoint:
if currentPeer == nil {
continue
}
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
if err != nil {
log.Errorf("failed to parse endpoint: %v", err)
continue
}
port, err := strconv.Atoi(portStr)
if err != nil {
log.Errorf("failed to parse endpoint port: %v", err)
continue
}
currentPeer.Endpoint = net.UDPAddr{
IP: net.ParseIP(host),
Port: port,
}
case allowedIP:
if currentPeer == nil {
continue
}
_, ipnet, err := net.ParseCIDR(val)
if err == nil {
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
}
case ipcKeyTxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.TxBytes = rxBytes
case ipcKeyRxBytes:
if currentPeer == nil {
continue
}
rxBytes, err := toBytes(val)
if err != nil {
continue
}
currentPeer.RxBytes = rxBytes
case ipcKeyLastHandshakeTimeSec:
if currentPeer == nil {
continue
}
ts, err := toLastHandshake(val)
if err != nil {
continue
}
currentPeer.LastHandshake = ts
case presharedKey:
if currentPeer == nil {
continue
}
if val != "" {
currentPeer.PresharedKey = true
}
}
}
if currentPeer != nil {
stats.Peers = append(stats.Peers, *currentPeer)
}
return stats, nil
}

View File

@@ -2,10 +2,8 @@ package configurer
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@@ -34,58 +32,35 @@ errno=0
` `
func Test_findPeerInfo(t *testing.T) { func Test_parseTransfers(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
peerKey string peerKey string
searchKeys []string want WGStats
want map[string]string
wantErr bool
}{ }{
{ {
name: "single", name: "single",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
searchKeys: []string{"tx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 0,
"tx_bytes": "38333", RxBytes: 0,
}, },
wantErr: false,
}, },
{ {
name: "multiple", name: "multiple",
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376", peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
searchKeys: []string{"tx_bytes", "rx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 38333,
"tx_bytes": "38333", RxBytes: 2224,
"rx_bytes": "2224",
}, },
wantErr: false,
}, },
{ {
name: "lastpeer", name: "lastpeer",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58", peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "rx_bytes"}, want: WGStats{
want: map[string]string{ TxBytes: 1212111,
"tx_bytes": "1212111", RxBytes: 1929999999,
"rx_bytes": "1929999999",
}, },
wantErr: false,
},
{
name: "peer not found",
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
searchKeys: nil,
want: nil,
wantErr: true,
},
{
name: "key not found",
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
searchKeys: []string{"tx_bytes", "unknown_key"},
want: map[string]string{
"tx_bytes": "1212111",
},
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
key, err := wgtypes.NewKey(res) key, err := wgtypes.NewKey(res)
require.NoError(t, err) require.NoError(t, err)
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys) stats, err := parseTransfers(ipcFixture)
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)) if err != nil {
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys) require.NoError(t, err)
return
}
stat, ok := stats[key.String()]
if !ok {
require.True(t, ok)
return
}
require.Equal(t, tt.want, stat)
}) })
} }
} }

View File

@@ -0,0 +1,24 @@
package configurer
import (
"net"
"time"
)
type Peer struct {
PublicKey string
Endpoint net.UDPAddr
AllowedIPs []net.IPNet
TxBytes int64
RxBytes int64
LastHandshake time.Time
PresharedKey bool
}
type Stats struct {
DeviceName string
PublicKey string
ListenPort int
FWMark int
Peers []Peer
}

View File

@@ -14,6 +14,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
} }
} }
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) { func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {

View File

@@ -1,7 +1,6 @@
package device package device
import ( import (
"net"
"net/netip" "net/netip"
"sync" "sync"
@@ -24,9 +23,6 @@ type PacketFilter interface {
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
// SetNetwork of the wireguard interface to which filtering applied
SetNetwork(*net.IPNet)
} }
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets

View File

@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create nbnetstack tun interface") log.Info("create nbnetstack tun interface")
// TODO: get from service listener runtime IP // TODO: get from service listener runtime IP
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1) dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
if err != nil {
return nil, fmt.Errorf("last ip: %w", err)
}
log.Debugf("netstack using address: %s", t.address.IP) log.Debugf("netstack using address: %s", t.address.IP)
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu) t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
log.Debugf("netstack using dns address: %s", dnsAddr) log.Debugf("netstack using dns address: %s", dnsAddr)

View File

@@ -16,5 +16,6 @@ type WGConfigurer interface {
AddAllowedIP(peerKey string, allowedIP string) error AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Close() Close()
GetStats(peerKey string) (configurer.WGStats, error) GetStats() (map[string]configurer.WGStats, error)
FullStats() (*configurer.Stats, error)
} }

View File

@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
} }
ip := address.IP.String() ip := address.IP.String()
mask := "0x" + address.Network.Mask.String()
// Convert prefix length to hex netmask
prefixLen := address.Network.Bits()
if !address.IP.Is4() {
return fmt.Errorf("IPv6 not supported for interface assignment")
}
maskBits := uint32(0xffffffff) << (32 - prefixLen)
mask := fmt.Sprintf("0x%08x", maskBits)
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name) log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)

View File

@@ -8,10 +8,11 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address wgaddr.Address) error UpdateAddr(address wgaddr.Address) error
WgAddress() wgaddr.Address WgAddress() wgaddr.Address

View File

@@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.FilteredDevice().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
@@ -212,9 +211,13 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
return w.tun.Device() return w.tun.Device()
} }
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
return w.configurer.GetStats(peerKey) return w.configurer.GetStats()
}
func (w *WGIface) FullStats() (*configurer.Stats, error) {
return w.configurer.FullStats()
} }
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {

View File

@@ -2,7 +2,11 @@
package iface package iface
import "fmt" import (
"fmt"
"github.com/netbirdio/netbird/management/domain"
)
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on non mobile") return fmt.Errorf("this function has not implemented on non mobile")
} }

View File

@@ -2,11 +2,13 @@ package iface
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/management/domain"
) )
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up. // CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error { func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()

View File

@@ -7,6 +7,8 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/netbirdio/netbird/management/domain"
) )
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
} }
// CreateOnAndroid this function make sense on mobile only // CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error { func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
return fmt.Errorf("this function has not implemented on this platform") return fmt.Errorf("this function has not implemented on this platform")
} }

View File

@@ -5,7 +5,6 @@
package mocks package mocks
import ( import (
net "net"
"net/netip" "net/netip"
reflect "reflect" reflect "reflect"
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
} }
// SetNetwork mocks base method.
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetNetwork", arg0)
}
// SetNetwork indicates an expected call of SetNetwork.
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
}

View File

@@ -1,8 +1,6 @@
package netstack package netstack
import ( import (
"fmt"
"net"
"net/netip" "net/netip"
"os" "os"
"strconv" "strconv"
@@ -15,8 +13,8 @@ import (
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY" const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
type NetStackTun struct { //nolint:revive type NetStackTun struct { //nolint:revive
address net.IP address netip.Addr
dnsAddress net.IP dnsAddress netip.Addr
mtu int mtu int
listenAddress string listenAddress string
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
tundev tun.Device tundev tun.Device
} }
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun { func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
return &NetStackTun{ return &NetStackTun{
address: address, address: address,
dnsAddress: dnsAddress, dnsAddress: dnsAddress,
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
} }
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) { func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
addr, ok := netip.AddrFromSlice(t.address)
if !ok {
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
}
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
if !ok {
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
}
nsTunDev, tunNet, err := netstack.CreateNetTUN( nsTunDev, tunNet, err := netstack.CreateNetTUN(
[]netip.Addr{addr.Unmap()}, []netip.Addr{t.address},
[]netip.Addr{dnsAddr.Unmap()}, []netip.Addr{t.dnsAddress},
t.mtu) t.mtu)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err

View File

@@ -2,28 +2,27 @@ package wgaddr
import ( import (
"fmt" "fmt"
"net" "net/netip"
) )
// Address WireGuard parsed address // Address WireGuard parsed address
type Address struct { type Address struct {
IP net.IP IP netip.Addr
Network *net.IPNet Network netip.Prefix
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (Address, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) prefix, err := netip.ParsePrefix(address)
if err != nil { if err != nil {
return Address{}, err return Address{}, err
} }
return Address{ return Address{
IP: ip, IP: prefix.Addr().Unmap(),
Network: network, Network: prefix.Masked(),
}, nil }, nil
} }
func (addr Address) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -24,6 +24,8 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" !define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -49,6 +51,10 @@ ShowInstDetails Show
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}" !define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -58,9 +64,6 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
!define MUI_UNABORTWARNING !define MUI_UNABORTWARNING
@@ -70,13 +73,16 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM !insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES !insertmacro MUI_UNPAGE_INSTFILES
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
Var AutostartCheckbox Var AutostartCheckbox
Var AutostartEnabled Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
###################################################################### ######################################################################
; Function to create the autostart options page ; Function to create the autostart options page
@@ -104,8 +114,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked ${NSD_Check} $AutostartCheckbox
StrCpy $AutostartEnabled "1" ; Default to enabled StrCpy $AutostartEnabled "1"
nsDialogs::Show nsDialogs::Show
FunctionEnd FunctionEnd
@@ -115,6 +125,30 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled ${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -225,31 +259,58 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000 Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk" Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk" Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR" EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd SectionEnd

View File

@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey( func GenerateRouteRuleKey(
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination manager.Network,
proto manager.Protocol, proto manager.Protocol,
sPort *manager.Port, sPort *manager.Port,
dPort *manager.Port, dPort *manager.Port,

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
} }
type protoMatch struct { type protoMatch struct {
@@ -53,10 +54,15 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
// //
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains. // If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
start := time.Now() start := time.Now()
defer func() { defer func() {
total := 0 total := 0
@@ -68,21 +74,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
time.Since(start), total) time.Since(start), total)
}() }()
if d.firewall == nil {
log.Debug("firewall manager is not supported, skipping firewall rules")
return
}
d.applyPeerACLs(networkMap) d.applyPeerACLs(networkMap)
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag, if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
// then the mgmt server is older than the client, and we need to allow all traffic for routes
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
log.Errorf("failed to set legacy management flag: %v", err)
}
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -176,16 +170,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs d.peerRulesPairs = newRulePairs
} }
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules)) newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present // Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules { for _, rule := range rules {
id, err := d.applyRouteACL(rule) id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil { if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) { if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
} else { } else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
} }
@@ -208,7 +202,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 { if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty return "", ErrSourceRangesEmpty
} }
@@ -222,15 +216,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
sources = append(sources, source) sources = append(sources, source)
} }
var destination netip.Prefix destination, err := determineDestination(rule, dynamicResolver, sources)
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil { if err != nil {
return "", fmt.Errorf("parse destination: %w", err) return "", fmt.Errorf("determine destination: %w", err)
}
} }
protocol, err := convertToFirewallProtocol(rule.Protocol) protocol, err := convertToFirewallProtocol(rule.Protocol)
@@ -296,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. if d.firewall.IsStateful() {
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already return "", nil, nil
}
// return traffic for outbound connections if firewall is stateless
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName) rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
@@ -580,6 +570,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil return nil
} }
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix { func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0) return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@@ -1,13 +1,14 @@
package acl package acl
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) assert.Equal(t, 0, len(acl.peerRulesPairs))
return } else {
assert.Equal(t, 2, len(acl.peerRulesPairs))
} }
}) })
@@ -92,14 +89,15 @@ func TestDefaultManager(t *testing.T) {
}, },
) )
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules expectedRules := 2
if len(acl.peerRulesPairs) != 2 { if fw.IsStateful() {
t.Errorf("firewall rules not applied") expectedRules = 1 // only the inbound rule
return
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
previousCount++ previousCount++
} }
} }
if previousCount != 1 {
t.Errorf("old rule was not removed") expectedPreviousCount := 0
if !fw.IsStateful() {
expectedPreviousCount = 1
} }
assert.Equal(t, expectedPreviousCount, previousCount)
}) })
t.Run("handle default rules", func(t *testing.T) { t.Run("handle default rules", func(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { acl.ApplyFiltering(networkMap, false)
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) assert.Equal(t, 0, len(acl.peerRulesPairs))
return
}
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) expectedRules := 1
return if fw.IsStateful() {
expectedRules = 1 // only inbound allow-all rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
})
}
func TestDefaultManagerStateless(t *testing.T) {
// stateless currently only in userspace, so we have to disable kernel
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv("NB_DISABLE_CONNTRACK", "true")
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_OUT,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "80",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
Port: "53",
},
},
}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
require.NoError(t, err)
}()
acl := NewDefaultManager(fw)
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap, false)
// In stateless mode, we should have both inbound and outbound rules
assert.False(t, fw.IsStateful())
assert.Equal(t, 2, len(acl.peerRulesPairs))
}) })
} }
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
manager := &DefaultManager{} manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap) rules, _ := manager.squashAcceptRules(networkMap)
if len(rules) != 2 { assert.Equal(t, 2, len(rules))
t.Errorf("rules should contain 2, got: %v", rules)
return
}
r := rules[0] r := rules[0]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_IN:
t.Errorf("direction should be IN, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
r = rules[1] r = rules[1]
switch { assert.Equal(t, "0.0.0.0", r.PeerIP)
case r.PeerIP != "0.0.0.0": assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP) assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
return assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
case r.Direction != mgmProto.RuleDirection_OUT:
t.Errorf("direction should be OUT, got: %v", r.Direction)
return
case r.Protocol != mgmProto.RuleProtocol_ALL:
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
return
case r.Action != mgmProto.RuleAction_ACCEPT:
t.Errorf("action should be ACCEPT, got: %v", r.Action)
return
}
} }
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) { func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
@@ -291,9 +326,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
} }
manager := &DefaultManager{} manager := &DefaultManager{}
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) { rules, _ := manager.squashAcceptRules(networkMap)
t.Errorf("we should get the same amount of rules as output, got %v", len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
}
} }
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
@@ -336,33 +370,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock := mocks.NewMockIFaceMapper(ctrl) ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes() ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any()) ifaceMock.EXPECT().SetFilter(gomock.Any())
ip, network, err := net.ParseCIDR("172.0.0.1/32") network := netip.MustParsePrefix("172.0.0.1/32")
if err != nil {
t.Fatalf("failed to parse IP address: %v", err)
}
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: network.Addr(),
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { require.NoError(t, err)
t.Errorf("create firewall: %v", err) defer func() {
return err = fw.Close(nil)
} require.NoError(t, err)
defer func(fw manager.Manager) { }()
_ = fw.Close(nil)
}(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 { expectedRules := 3
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) if fw.IsStateful() {
return expectedRules = 3 // 2 inbound rules + SSH rule
} }
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
} }

View File

@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }

View File

@@ -94,12 +94,22 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
p.codeVerifier = codeVerifier p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier) codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
state, params := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
) }
if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
return AuthFlowInfo{ return AuthFlowInfo{
VerificationURIComplete: authURL, VerificationURIComplete: authURL,

View File

@@ -0,0 +1,71 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
)
func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct {
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
LoginFlag: tc.loginFlag,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
}
authInfo, err := pkce.RequestAuthInfo(context.Background())
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
}
})
}
}

View File

@@ -68,12 +68,14 @@ type ConfigInput struct {
DisableServerRoutes *bool DisableServerRoutes *bool
DisableDNS *bool DisableDNS *bool
DisableFirewall *bool DisableFirewall *bool
BlockLANAccess *bool BlockLANAccess *bool
BlockInbound *bool
DisableNotifications *bool DisableNotifications *bool
DNSLabels domain.List DNSLabels domain.List
LazyConnectionEnabled *bool
} }
// Config Configuration type // Config Configuration type
@@ -96,8 +98,8 @@ type Config struct {
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool
DisableNotifications *bool DisableNotifications *bool
@@ -138,6 +140,8 @@ type Config struct {
ClientCertKeyPath string ClientCertKeyPath string
ClientCertKeyPair *tls.Certificate `json:"-"` ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
} }
// ReadConfig read config file and return with Config. If it is not exists create a new with default values // ReadConfig read config file and return with Config. If it is not exists create a new with default values
@@ -479,6 +483,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
if *input.BlockInbound {
log.Infof("blocking inbound connections")
} else {
log.Infof("allowing inbound connections")
}
config.BlockInbound = *input.BlockInbound
updated = true
}
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications { if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
if *input.DisableNotifications { if *input.DisableNotifications {
log.Infof("disabling notifications") log.Infof("disabling notifications")
@@ -524,6 +538,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
return updated, nil return updated, nil
} }

303
client/internal/conn_mgr.go Normal file
View File

@@ -0,0 +1,303 @@
package internal
import (
"context"
"os"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
"github.com/netbirdio/netbird/client/internal/peerstore"
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
// - Managing lazy connections via the lazyConnManager
// - Maintaining a list of excluded peers that should always have permanent connections
// - Handling connection establishment based on peer signaling
//
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
dispatcher *dispatcher.ConnectionDispatcher
enabledLocally bool
lazyConnMgr *manager.Manager
wg sync.WaitGroup
ctx context.Context
ctxCancel context.CancelFunc
}
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
e := &ConnMgr{
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
dispatcher: dispatcher,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
return
}
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
}
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
return nil
}
if enabled {
// if the lazy connection manager is already started, do not start it again
if e.lazyConnMgr != nil {
return nil
}
log.Infof("lazy connection manager is enabled by management feature flag")
e.initLazyManager(ctx)
e.statusRecorder.UpdateLazyConnection(true)
return e.addPeersToLazyConnManager(ctx)
} else {
if e.lazyConnMgr == nil {
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
e.closeManager(ctx)
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
}
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
if e.lazyConnMgr == nil {
return
}
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
for peerID := range peerIDs {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
excludedPeers = append(excludedPeers, lazyPeerCfg)
}
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
for _, peerID := range added {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
continue
}
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
if err := peerConn.Open(e.ctx); err != nil {
peerConn.Log.Errorf("failed to open connection: %v", err)
}
}
}
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
return true
}
if !e.isStartedWithLazyMgr() {
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if !lazyconn.IsSupported(conn.AgentVersionString()) {
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerKey,
AllowedIPs: conn.WgConfig().AllowedIps,
PeerConnID: conn.ConnID(),
Log: conn.Log,
}
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
if excluded {
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
if err := conn.Open(ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
return
}
conn.Log.Infof("peer added to lazy conn manager")
return
}
func (e *ConnMgr) RemovePeerConn(peerKey string) {
conn, ok := e.peerStore.Remove(peerKey)
if !ok {
return
}
defer conn.Close()
if !e.isStartedWithLazyMgr() {
return
}
e.lazyConnMgr.RemovePeer(peerKey)
conn.Log.Infof("removed peer from lazy conn manager")
}
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
conn, ok := e.peerStore.PeerConn(peerKey)
if !ok {
return nil, false
}
if !e.isStartedWithLazyMgr() {
return conn, true
}
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
conn.Log.Infof("activated peer from inactive state")
if err := conn.Open(e.ctx); err != nil {
conn.Log.Errorf("failed to open connection: %v", err)
}
}
return conn, true
}
func (e *ConnMgr) Close() {
if !e.isStartedWithLazyMgr() {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
}
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
cfg := manager.Config{
InactivityThreshold: inactivityThresholdEnv(),
}
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
ctx, cancel := context.WithCancel(parentCtx)
e.ctx = ctx
e.ctxCancel = cancel
e.wg.Add(1)
go func() {
defer e.wg.Done()
e.lazyConnMgr.Start(ctx)
}()
}
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
peers := e.peerStore.PeersPubKey()
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
for _, peerID := range peers {
var peerConn *peer.Conn
var exists bool
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
log.Warnf("failed to find peer conn for peerID: %s", peerID)
continue
}
lazyPeerCfg := lazyconn.PeerConfig{
PublicKey: peerID,
AllowedIPs: peerConn.WgConfig().AllowedIps,
PeerConnID: peerConn.ConnID(),
Log: peerConn.Log,
}
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
}
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
}
func (e *ConnMgr) closeManager(ctx context.Context) {
if e.lazyConnMgr == nil {
return
}
e.ctxCancel()
e.wg.Wait()
e.lazyConnMgr = nil
for _, peerID := range e.peerStore.PeersPubKey() {
e.peerStore.PeerConnOpen(ctx, peerID)
}
}
func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.ctxCancel != nil
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {
return nil
}
parsedMinutes, err := strconv.Atoi(envValue)
if err != nil || parsedMinutes <= 0 {
return nil
}
d := time.Duration(parsedMinutes) * time.Minute
return &d
}

View File

@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
return e return e
} }
// GetLatestNetworkMap returns the latest network map from the engine.
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
engine := c.Engine()
if engine == nil {
return nil, errors.New("engine is not initialized")
}
networkMap, err := engine.GetLatestNetworkMap()
if err != nil {
return nil, fmt.Errorf("get latest network map: %w", err)
}
if networkMap == nil {
return nil, errors.New("network map is not available")
}
return networkMap, nil
}
// Status returns the current client status // Status returns the current client status
func (c *ConnectClient) Status() StatusType { func (c *ConnectClient) Status() StatusType {
if c == nil { if c == nil {
@@ -417,11 +436,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
DNSRouteInterval: config.DNSRouteInterval, DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes, DisableClientRoutes: config.DisableClientRoutes,
DisableServerRoutes: config.DisableServerRoutes, DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
DisableDNS: config.DisableDNS, DisableDNS: config.DisableDNS,
DisableFirewall: config.DisableFirewall, DisableFirewall: config.DisableFirewall,
BlockLANAccess: config.BlockLANAccess, BlockLANAccess: config.BlockLANAccess,
BlockInbound: config.BlockInbound,
LazyConnectionEnabled: config.LazyConnectionEnabled,
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@@ -462,7 +483,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
return signalClient, nil return signalClient, nil
} }
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc) // loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) { func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
serverPublicKey, err := client.GetServerPublicKey() serverPublicKey, err := client.GetServerPublicKey()

File diff suppressed because it is too large Load Diff

View File

@@ -1,49 +1,130 @@
//go:build linux && !android //go:build linux && !android
package server package debug
import ( import (
"archive/zip"
"bytes" "bytes"
"context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"os"
"os/exec" "os/exec"
"sort" "sort"
"strings" "strings"
"time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/proto"
) )
const (
maxLogEntries = 100000
maxLogAge = 7 * 24 * time.Hour // Last 7 days
)
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
func (g *BundleGenerator) trySystemdLogFallback() error {
log.Debug("Attempting to collect systemd journal logs")
serviceName := getServiceName()
journalLogs, err := getSystemdLogs(serviceName)
if err != nil {
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
}
if strings.Contains(journalLogs, "No recent log entries found") {
log.Debug("No recent log entries found in systemd journal")
return nil
}
if g.anonymize {
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
}
logReader := strings.NewReader(journalLogs)
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
if err := g.addFileToZip(logReader, fileName); err != nil {
return fmt.Errorf("add systemd logs to bundle: %w", err)
}
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
return nil
}
// getServiceName gets the service name from environment or defaults to netbird
func getServiceName() string {
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
return unitName
}
return "netbird"
}
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
func getSystemdLogs(serviceName string) (string, error) {
args := []string{
"-u", fmt.Sprintf("%s.service", serviceName),
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
"--lines", fmt.Sprintf("%d", maxLogEntries),
"--no-pager",
"--output", "short-iso",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "journalctl", args...)
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
}
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("journalctl command not found: %w", err)
}
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
}
logs := stdout.String()
if strings.TrimSpace(logs) == "" {
return "No recent log entries found in systemd journal", nil
}
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
serviceName, maxLogEntries, maxLogAge.String())
return header + logs, nil
}
// addFirewallRules collects and adds firewall rules to the archive // addFirewallRules collects and adds firewall rules to the archive
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules") log.Info("Collecting firewall rules")
// Collect and add iptables rules
iptablesRules, err := collectIPTablesRules() iptablesRules, err := collectIPTablesRules()
if err != nil { if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err) log.Warnf("Failed to collect iptables rules: %v", err)
} else { } else {
if req.GetAnonymize() { if g.anonymize {
iptablesRules = anonymizer.AnonymizeString(iptablesRules) iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
} }
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil { if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err) log.Warnf("Failed to add iptables rules to bundle: %v", err)
} }
} }
// Collect and add nftables rules
nftablesRules, err := collectNFTablesRules() nftablesRules, err := collectNFTablesRules()
if err != nil { if err != nil {
log.Warnf("Failed to collect nftables rules: %v", err) log.Warnf("Failed to collect nftables rules: %v", err)
} else { } else {
if req.GetAnonymize() { if g.anonymize {
nftablesRules = anonymizer.AnonymizeString(nftablesRules) nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
} }
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil { if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
log.Warnf("Failed to add nftables rules to bundle: %v", err) log.Warnf("Failed to add nftables rules to bundle: %v", err)
} }
} }
@@ -65,16 +146,23 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n") builder.WriteString("\n")
} }
// Then get verbose statistics for each table // Collect ipset information
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n") builder.WriteString("=== iptables -v -n -L output ===\n")
// Get list of tables
tables := []string{"filter", "nat", "mangle", "raw", "security"} tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables { for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table)) builder.WriteString(fmt.Sprintf("*%s\n", table))
// Get verbose statistics for the entire table
stats, err := getTableStatistics(table) stats, err := getTableStatistics(table)
if err != nil { if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err) log.Warnf("Failed to get statistics for table %s: %v", table, err)
@@ -87,6 +175,28 @@ func collectIPTablesRules() (string, error) {
return builder.String(), nil return builder.String(), nil
} }
// collectIPSets collects information about ipsets
func collectIPSets() (string, error) {
cmd := exec.Command("ipset", "list")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("ipset command not found: %w", err)
}
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
}
ipsets := stdout.String()
if strings.TrimSpace(ipsets) == "" {
return "No ipsets found", nil
}
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions // collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) { func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save") cmd := exec.Command("iptables-save")
@@ -182,12 +292,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
continue continue
} }
// Format chains
for _, chain := range chains { for _, chain := range chains {
formatChain(conn, table, chain, &builder) formatChain(conn, table, chain, &builder)
} }
// Format sets
if sets, err := conn.GetSets(table); err != nil { if sets, err := conn.GetSets(table); err != nil {
log.Warnf("Failed to get sets for table %s: %v", table.Name, err) log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
} else if len(sets) > 0 { } else if len(sets) > 0 {
@@ -460,7 +568,7 @@ func formatExpr(exp expr.Any) string {
case *expr.Fib: case *expr.Fib:
return formatFib(e) return formatFib(e)
case *expr.Target: case *expr.Target:
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets return fmt.Sprintf("jump %s", e.Name)
case *expr.Immediate: case *expr.Immediate:
if e.Register == 1 { if e.Register == 1 {
return formatImmediateData(e.Data) return formatImmediateData(e.Data)

View File

@@ -0,0 +1,7 @@
//go:build ios || android
package debug
func (g *BundleGenerator) addRoutes() error {
return nil
}

View File

@@ -0,0 +1,14 @@
//go:build !linux || android
package debug
// collectFirewallRules returns nothing on non-linux systems
func (g *BundleGenerator) addFirewallRules() error {
return nil
}
func (g *BundleGenerator) trySystemdLogFallback() error {
// Systemd is only available on Linux
// TODO: Add BSD support
return nil
}

View File

@@ -0,0 +1,25 @@
//go:build !ios && !android
package debug
import (
"fmt"
"strings"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func (g *BundleGenerator) addRoutes() error {
routes, err := systemops.GetRoutesFromTable()
if err != nil {
return fmt.Errorf("get routes: %w", err)
}
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
routesReader := strings.NewReader(routesContent)
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
return nil
}

View File

@@ -0,0 +1,543 @@
package debug
import (
"encoding/json"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
"test_state": mustMarshal(map[string]any{
// Test simple fields
"public_ip": "203.0.113.1",
"private_ip": "192.168.1.1",
"protected_ip": "100.64.0.1",
"well_known_ip": "8.8.8.8",
"ipv6_addr": "2001:db8::1",
"private_ipv6": "fd00::1",
"domain": "test.example.com",
"uri": "stun:stun.example.com:3478",
"uri_with_ip": "turn:203.0.113.1:3478",
"netbird_domain": "device.netbird.cloud",
// Test CIDR ranges
"public_cidr": "203.0.113.0/24",
"private_cidr": "192.168.0.0/16",
"protected_cidr": "100.64.0.0/10",
"ipv6_cidr": "2001:db8::/32",
"private_ipv6_cidr": "fd00::/8",
// Test nested structures
"nested": map[string]any{
"ip": "203.0.113.2",
"domain": "nested.example.com",
"more_nest": map[string]any{
"ip": "203.0.113.3",
"domain": "deep.example.com",
},
},
// Test arrays
"string_array": []any{
"203.0.113.4",
"test1.example.com",
"test2.example.com",
},
"object_array": []any{
map[string]any{
"ip": "203.0.113.5",
"domain": "array1.example.com",
},
map[string]any{
"ip": "203.0.113.6",
"domain": "array2.example.com",
},
},
// Test multiple occurrences of same value
"duplicate_ip": "203.0.113.1", // Same as public_ip
"duplicate_domain": "test.example.com", // Same as domain
// Test URIs with various schemes
"stun_uri": "stun:stun.example.com:3478",
"turns_uri": "turns:turns.example.com:5349",
"http_uri": "http://web.example.com:80",
"https_uri": "https://secure.example.com:443",
// Test strings that might look like IPs but aren't
"not_ip": "300.300.300.300",
"partial_ip": "192.168",
"ip_like_string": "1234.5678",
// Test mixed content strings
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
// Test empty and special values
"empty_string": "",
"null_value": nil,
"numeric_value": 42,
"boolean_value": true,
}),
"route_state": mustMarshal(map[string]any{
"routes": []any{
map[string]any{
"network": "203.0.113.0/24",
"gateway": "203.0.113.1",
"domains": []any{
"route1.example.com",
"route2.example.com",
},
},
map[string]any{
"network": "2001:db8::/32",
"gateway": "2001:db8::1",
"domains": []any{
"route3.example.com",
"route4.example.com",
},
},
},
// Test map with IP/CIDR keys
"refCountMap": map[string]any{
"203.0.113.1/32": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"2001:db8::1/128": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "fe80::1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
},
},
},
}),
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Pre-seed the domains we need to verify in the test assertions
anonymizer.AnonymizeDomain("test.example.com")
anonymizer.AnonymizeDomain("nested.example.com")
anonymizer.AnonymizeDomain("deep.example.com")
anonymizer.AnonymizeDomain("array1.example.com")
err := anonymizeStateFile(&testState, anonymizer)
require.NoError(t, err)
// Helper function to unmarshal and get nested values
var state map[string]any
err = json.Unmarshal(testState["test_state"], &state)
require.NoError(t, err)
// Test null state remains unchanged
require.Equal(t, "null", string(testState["null_state"]))
// Basic assertions
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
// CIDR ranges
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
// Nested structures
nested := state["nested"].(map[string]any)
assert.NotEqual(t, "203.0.113.2", nested["ip"])
assert.NotEqual(t, "nested.example.com", nested["domain"])
moreNest := nested["more_nest"].(map[string]any)
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
// Arrays
strArray := state["string_array"].([]any)
assert.NotEqual(t, "203.0.113.4", strArray[0])
assert.NotEqual(t, "test1.example.com", strArray[1])
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
objArray := state["object_array"].([]any)
firstObj := objArray[0].(map[string]any)
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
// Duplicate values should be anonymized consistently
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
assert.Equal(t, state["domain"], state["duplicate_domain"])
// URIs
assert.NotContains(t, state["stun_uri"], "stun.example.com")
assert.NotContains(t, state["turns_uri"], "turns.example.com")
assert.NotContains(t, state["http_uri"], "web.example.com")
assert.NotContains(t, state["https_uri"], "secure.example.com")
// Non-IP strings should remain unchanged
assert.Equal(t, "300.300.300.300", state["not_ip"])
assert.Equal(t, "192.168", state["partial_ip"])
assert.Equal(t, "1234.5678", state["ip_like_string"])
// Mixed content should have IPs and domains replaced
mixedContent := state["mixed_content"].(string)
assert.NotContains(t, mixedContent, "203.0.113.1")
assert.NotContains(t, mixedContent, "test.example.com")
assert.Contains(t, mixedContent, "Server at ")
assert.Contains(t, mixedContent, " on port 80")
// Special values should remain unchanged
assert.Equal(t, "", state["empty_string"])
assert.Nil(t, state["null_value"])
assert.Equal(t, float64(42), state["numeric_value"])
assert.Equal(t, true, state["boolean_value"])
// Check route state
var routeState map[string]any
err = json.Unmarshal(testState["route_state"], &routeState)
require.NoError(t, err)
routes := routeState["routes"].([]any)
route1 := routes[0].(map[string]any)
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
assert.Contains(t, route1["network"], "/24")
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
domains := route1["domains"].([]any)
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
// Check map keys are anonymized
refCountMap := routeState["refCountMap"].(map[string]any)
hasPublicIPKey := false
hasIPv6Key := false
hasPrivateIPKey := false
for key := range refCountMap {
if strings.Contains(key, "203.0.113.1") {
hasPublicIPKey = true
}
if strings.Contains(key, "2001:db8::1") {
hasIPv6Key = true
}
if key == "10.0.0.1/32" {
hasPrivateIPKey = true
}
}
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
func TestAnonymizeNetworkMap(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{
AllowedIps: []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
},
Fqdn: "peer2.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
},
},
},
Routes: []*mgmProto.Route{
{
Network: "197.51.100.0/24",
Domains: []string{"prod.example.com", "staging.example.com"},
NetID: "net-123abc",
},
},
DNSConfig: &mgmProto.DNSConfig{
NameServerGroups: []*mgmProto.NameServerGroup{
{
NameServers: []*mgmProto.NameServer{
{IP: "8.8.8.8"},
{IP: "1.1.1.1"},
{IP: "203.0.113.53"},
},
Domains: []string{"example.com", "internal.example.com"},
},
},
CustomZones: []*mgmProto.CustomZone{
{
Domain: "custom.example.com",
Records: []*mgmProto.SimpleRecord{
{
Name: "www.custom.example.com",
Type: 1,
RData: "203.0.113.10",
},
{
Name: "internal.custom.example.com",
Type: 1,
RData: "192.168.1.10",
},
},
},
},
},
}
// Create anonymizer with test addresses
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Anonymize the network map
err := anonymizeNetworkMap(networkMap, anonymizer)
require.NoError(t, err)
// Test PeerConfig anonymization
peerCfg := networkMap.PeerConfig
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
// Verify DNS and FQDN are properly anonymized
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
// Test RemotePeers anonymization
remotePeer := networkMap.RemotePeers[0]
// Verify FQDN is anonymized
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
// Check that public IPs are anonymized but private IPs are preserved
for _, allowedIP := range remotePeer.AllowedIps {
ip, _, err := net.ParseCIDR(allowedIP)
require.NoError(t, err)
if ip.IsPrivate() || isInCGNATRange(ip) {
require.Contains(t, []string{
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
}, allowedIP)
} else {
require.NotContains(t, []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
}, allowedIP)
}
}
// Test Routes anonymization
route := networkMap.Routes[0]
require.NotEqual(t, "197.51.100.0/24", route.Network)
for _, domain := range route.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test DNS config anonymization
dnsConfig := networkMap.DNSConfig
nameServerGroup := dnsConfig.NameServerGroups[0]
// Verify well-known DNS servers are preserved
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
// Verify public DNS server is anonymized
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
// Verify domains are anonymized
for _, domain := range nameServerGroup.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test CustomZones anonymization
customZone := dnsConfig.CustomZones[0]
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
require.NotContains(t, customZone.Domain, "example.com")
// Verify records are properly anonymized
for _, record := range customZone.Records {
require.True(t, strings.HasSuffix(record.Name, ".domain"))
require.NotContains(t, record.Name, "example.com")
ip := net.ParseIP(record.RData)
if ip != nil {
if !ip.IsPrivate() {
require.NotEqual(t, "203.0.113.10", record.RData)
} else {
require.Equal(t, "192.168.1.10", record.RData)
}
}
}
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
return cgnat.Contains(ip)
}
func TestAnonymizeFirewallRules(t *testing.T) {
// TODO: Add ipv6
// Example iptables-save output
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter
:INPUT ACCEPT [0:0]
:FORWARD ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
-A INPUT -s 192.168.1.0/24 -j ACCEPT
-A INPUT -s 44.192.140.1/32 -j DROP
-A FORWARD -s 10.0.0.0/8 -j DROP
-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT
COMMIT
*nat
:PREROUTING ACCEPT [0:0]
:INPUT ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
:POSTROUTING ACCEPT [0:0]
-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE
-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80
COMMIT`
// Example iptables -v -n -L output
iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination
0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0
100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0
Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination
0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0
25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination`
// Example nftables output
nftablesRules := `table inet filter {
chain input {
type filter hook input priority filter; policy accept;
ip saddr 192.168.1.1 accept
ip saddr 44.192.140.1 drop
}
chain forward {
type filter hook forward priority filter; policy accept;
ip saddr 10.0.0.0/8 drop
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
}
}`
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Test iptables-save anonymization
anonIptablesSave := anonymizer.AnonymizeString(iptablesSave)
// Private IP addresses should remain unchanged
assert.Contains(t, anonIptablesSave, "192.168.1.0/24")
assert.Contains(t, anonIptablesSave, "10.0.0.0/8")
assert.Contains(t, anonIptablesSave, "192.168.100.0/24")
assert.Contains(t, anonIptablesSave, "192.168.1.10")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonIptablesSave, "44.192.140.1")
assert.NotContains(t, anonIptablesSave, "44.192.140.0/24")
assert.NotContains(t, anonIptablesSave, "52.84.12.34")
assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range
// Structure should be preserved
assert.Contains(t, anonIptablesSave, "*filter")
assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]")
assert.Contains(t, anonIptablesSave, "COMMIT")
assert.Contains(t, anonIptablesSave, "-j MASQUERADE")
assert.Contains(t, anonIptablesSave, "--dport 80")
// Test iptables verbose output anonymization
anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose)
// Private IP addresses should remain unchanged
assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24")
assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonIptablesVerbose, "44.192.140.1")
assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24")
assert.NotContains(t, anonIptablesVerbose, "52.84.12.34")
assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range
// Structure and counters should be preserved
assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)")
assert.Contains(t, anonIptablesVerbose, "100 1024 DROP")
assert.Contains(t, anonIptablesVerbose, "pkts bytes target")
// Test nftables anonymization
anonNftables := anonymizer.AnonymizeString(nftablesRules)
// Private IP addresses should remain unchanged
assert.Contains(t, anonNftables, "192.168.1.1")
assert.Contains(t, anonNftables, "10.0.0.0/8")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonNftables, "44.192.140.1")
assert.NotContains(t, anonNftables, "44.192.140.0/24")
assert.NotContains(t, anonNftables, "52.84.12.34")
assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range
// Structure should be preserved
assert.Contains(t, anonNftables, "table inet filter {")
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}

View File

@@ -0,0 +1,66 @@
package debug
import (
"bytes"
"fmt"
"strings"
"time"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGIface interface {
FullStats() (*configurer.Stats, error)
}
func (g *BundleGenerator) addWgShow() error {
result, err := g.statusRecorder.PeersStatus()
if err != nil {
return err
}
output := g.toWGShowFormat(result)
reader := bytes.NewReader([]byte(output))
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
return fmt.Errorf("add wg show to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
if s.FWMark != 0 {
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
}
for _, peer := range s.Peers {
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
if peer.Endpoint.IP != nil {
if g.anonymize {
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
} else {
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
}
}
if len(peer.AllowedIPs) > 0 {
var ipStrings []string
for _, ipnet := range peer.AllowedIPs {
ipStrings = append(ipStrings, ipnet.String())
}
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
}
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
if peer.PresharedKey {
sb.WriteString(" preshared key: (hidden)\n")
}
}
return sb.String()
}

View File

@@ -2,7 +2,7 @@ package internal
import ( import (
"fmt" "fmt"
"net" "net/netip"
"slices" "slices"
"strings" "strings"
@@ -12,13 +12,14 @@ import (
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) { func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
ip := net.ParseIP(aRecord.RData) ip, err := netip.ParseAddr(aRecord.RData)
if ip == nil || ip.To4() == nil { if err != nil {
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
if !ipNet.Contains(ip) { if !prefix.Contains(ip) {
return nbdns.SimpleRecord{}, false return nbdns.SimpleRecord{}, false
} }
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
} }
// generateReverseZoneName creates the reverse DNS zone name for a given network // generateReverseZoneName creates the reverse DNS zone name for a given network
func generateReverseZoneName(ipNet *net.IPNet) (string, error) { func generateReverseZoneName(network netip.Prefix) (string, error) {
networkIP := ipNet.IP.Mask(ipNet.Mask) networkIP := network.Masked().Addr()
maskOnes, _ := ipNet.Mask.Size()
if !networkIP.Is4() {
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
}
// round up to nearest byte // round up to nearest byte
octetsToUse := (maskOnes + 7) / 8 octetsToUse := (network.Bits() + 7) / 8
octets := strings.Split(networkIP.String(), ".") octets := strings.Split(networkIP.String(), ".")
if octetsToUse > len(octets) { if octetsToUse > len(octets) {
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes) return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
} }
reverseOctets := make([]string, octetsToUse) reverseOctets := make([]string, octetsToUse)
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
} }
// collectPTRRecords gathers all PTR records for the given network from A records // collectPTRRecords gathers all PTR records for the given network from A records
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord { func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
var records []nbdns.SimpleRecord var records []nbdns.SimpleRecord
for _, zone := range config.CustomZones { for _, zone := range config.CustomZones {
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
continue continue
} }
if ptrRecord, ok := createPTRRecord(record, ipNet); ok { if ptrRecord, ok := createPTRRecord(record, prefix); ok {
records = append(records, ptrRecord) records = append(records, ptrRecord)
} }
} }
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
} }
// addReverseZone adds a reverse DNS zone to the configuration for the given network // addReverseZone adds a reverse DNS zone to the configuration for the given network
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) { func addReverseZone(config *nbdns.Config, network netip.Prefix) {
zoneName, err := generateReverseZoneName(ipNet) zoneName, err := generateReverseZoneName(network)
if err != nil { if err != nil {
log.Warn(err) log.Warn(err)
return return
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
return return
} }
records := collectPTRRecords(config, ipNet) records := collectPTRRecords(config, network)
reverseZone := nbdns.CustomZone{ reverseZone := nbdns.CustomZone{
Domain: zoneName, Domain: zoneName,

View File

@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue continue
} }
listOfDomains = append(listOfDomains, dConf.Domain) listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
} }
return listOfDomains return listOfDomains
} }

View File

@@ -1,12 +1,15 @@
package dns package dns
import ( import (
"fmt"
"slices" "slices"
"strings" "strings"
"sync" "sync"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
) )
const ( const (
@@ -23,8 +26,8 @@ type SubdomainMatcher interface {
type HandlerEntry struct { type HandlerEntry struct {
Handler dns.Handler Handler dns.Handler
Priority int Priority int
Pattern string Pattern domain.Domain
OrigPattern string OrigPattern domain.Domain
IsWildcard bool IsWildcard bool
MatchSubdomains bool MatchSubdomains bool
} }
@@ -38,7 +41,7 @@ type HandlerChain struct {
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain // ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
type ResponseWriterChain struct { type ResponseWriterChain struct {
dns.ResponseWriter dns.ResponseWriter
origPattern string origPattern domain.Domain
shouldContinue bool shouldContinue bool
} }
@@ -58,29 +61,24 @@ func NewHandlerChain() *HandlerChain {
} }
// GetOrigPattern returns the original pattern of the handler that wrote the response // GetOrigPattern returns the original pattern of the handler that wrote the response
func (w *ResponseWriterChain) GetOrigPattern() string { func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
return w.origPattern return w.origPattern
} }
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority // AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) { func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern)) pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
origPattern := pattern origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.") isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
if isWildcard { if isWildcard {
pattern = pattern[2:] pattern = pattern[2:]
} }
// First remove any existing handler with same pattern (case-insensitive) and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { c.removeEntry(origPattern, priority)
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
// Check if handler implements SubdomainMatcher interface // Check if handler implements SubdomainMatcher interface
matchSubdomains := false matchSubdomains := false
@@ -114,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
// domain specificity next // domain specificity next
if h.Priority == newEntry.Priority { if h.Priority == newEntry.Priority {
newDots := strings.Count(newEntry.Pattern, ".") newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
existingDots := strings.Count(h.Pattern, ".") existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
if newDots > existingDots { if newDots > existingDots {
return i return i
} }
@@ -127,83 +125,53 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
} }
// RemoveHandler removes a handler for the given pattern and priority // RemoveHandler removes a handler for the given pattern and priority
func (c *HandlerChain) RemoveHandler(pattern string, priority int) { func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = dns.Fqdn(pattern) pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
c.removeEntry(pattern, priority)
}
func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return break
} }
} }
} }
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
qname := strings.ToLower(r.Question[0].Name) qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
handlers := slices.Clone(c.handlers) handlers := slices.Clone(c.handlers)
c.mu.RUnlock() c.mu.RUnlock()
if log.IsLevelEnabled(log.TraceLevel) { if log.IsLevelEnabled(log.TraceLevel) {
log.Tracef("current handlers (%d):", len(handlers)) var b strings.Builder
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
for _, h := range handlers { for _, h := range handlers {
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority) h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
} }
log.Trace(strings.TrimSuffix(b.String(), "\n"))
} }
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
var matched bool matched := c.isHandlerMatch(qname, entry)
switch {
case entry.Pattern == ".":
matched = true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = strings.EqualFold(qname, entry.Pattern)
}
}
if !matched { if matched {
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false", log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
continue
}
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority) qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
chainWriter := &ResponseWriterChain{ chainWriter := &ResponseWriterChain{
@@ -214,11 +182,12 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants to continue, try next handler // If handler wants to continue, try next handler
if chainWriter.shouldContinue { if chainWriter.shouldContinue {
log.Tracef("handler requested continue to next handler") log.Tracef("handler requested continue to next handler for domain=%s", qname)
continue continue
} }
return return
} }
}
// No handler matched or all handlers passed // No handler matched or all handlers passed
log.Tracef("no handler found for domain=%s", qname) log.Tracef("no handler found for domain=%s", qname)
@@ -228,3 +197,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
log.Errorf("failed to write DNS response: %v", err) log.Errorf("failed to write DNS response: %v", err)
} }
} }
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
return true
case entry.IsWildcard:
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
default:
// For non-wildcard patterns:
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
} else {
return strings.EqualFold(qname, entry.Pattern.PunycodeString())
}
}
}

View File

@@ -1,7 +1,6 @@
package dns_test package dns_test
import ( import (
"net"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -9,6 +8,8 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/management/domain"
) )
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -30,7 +31,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
r.SetQuestion("example.com.", dns.TypeA) r.SetQuestion("example.com.", dns.TypeA)
// Create test writer // Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations - only highest priority handler should be called // Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
handlerDomain string handlerDomain domain.Domain
queryDomain string queryDomain domain.Domain
isWildcard bool isWildcard bool
matchSubdomains bool matchSubdomains bool
shouldMatch bool shouldMatch bool
@@ -141,8 +142,8 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
chain.AddHandler(pattern, handler, nbdns.PriorityDefault) chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
handlers []struct { handlers []struct {
pattern string pattern domain.Domain
priority int priority int
} }
queryDomain string queryDomain domain.Domain
expectedCalls int expectedCalls int
expectedHandler int // index of the handler that should be called expectedHandler int // index of the handler that should be called
}{ }{
{ {
name: "wildcard and exact same priority - exact should win", name: "wildcard and exact same priority - exact should win",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "higher priority wildcard over lower priority exact", name: "higher priority wildcard over lower priority exact",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "example.com.", priority: nbdns.PriorityDefault}, {pattern: "example.com.", priority: nbdns.PriorityDefault},
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "multiple wildcards different priorities", name: "multiple wildcards different priorities",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "subdomain with mix of patterns", name: "subdomain with mix of patterns",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
{ {
name: "root zone with specific domain", name: "root zone with specific domain",
handlers: []struct { handlers: []struct {
pattern string pattern domain.Domain
priority int priority int
}{ }{
{pattern: ".", priority: nbdns.PriorityDefault}, {pattern: ".", priority: nbdns.PriorityDefault},
@@ -258,8 +259,8 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request // Create and execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
// Verify expectations // Verify expectations
@@ -316,7 +317,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
}).Once() }).Once()
// Execute // Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
// Verify all handlers were called in order // Verify all handlers were called in order
@@ -325,26 +326,12 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3.AssertExpectations(t) handler3.AssertExpectations(t)
} }
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) { func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
ops []struct { ops []struct {
action string // "add" or "remove" action string // "add" or "remove"
pattern string pattern domain.Domain
priority int priority int
} }
query string query string
@@ -354,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove high priority keeps lower priority handler", name: "remove high priority keeps lower priority handler",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -371,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove lower priority keeps high priority handler", name: "remove lower priority keeps high priority handler",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -388,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
name: "remove all handlers in order", name: "remove all handlers in order",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
@@ -425,7 +412,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations // Setup expectations
for priority, handler := range handlers { for priority, handler := range handlers {
@@ -443,14 +430,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
for _, handler := range handlers { for _, handler := range handlers {
handler.AssertExpectations(t) handler.AssertExpectations(t)
} }
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
}) })
} }
} }
@@ -458,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
func TestHandlerChain_MultiPriorityHandling(t *testing.T) { func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
testDomain := "example.com." testDomain := domain.Domain("example.com.")
testQuery := "test.example.com." testQuery := "test.example.com."
// Create handlers with MatchSubdomains enabled // Create handlers with MatchSubdomains enabled
@@ -470,45 +449,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA) r.SetQuestion(testQuery, dns.TypeA)
// Keep track of mocks for the final assertion in Step 4
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers // Test 1: Initial state
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Highest priority handler (routeHandler) should be called // Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
chain.ServeDNS(w, r) chain.ServeDNS(w1, r)
routeHandler.AssertExpectations(t) routeHandler.AssertExpectations(t)
routeHandler.ExpectedCalls = nil
routeHandler.Calls = nil
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 2: Remove highest priority handler // Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called // Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
chain.ServeDNS(w, r) chain.ServeDNS(w2, r)
matchHandler.AssertExpectations(t) matchHandler.AssertExpectations(t)
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r) chain.ServeDNS(w3, r)
defaultHandler.AssertExpectations(t) defaultHandler.AssertExpectations(t)
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain)) w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
m.AssertNumberOfCalls(t, "ServeDNS", 0)
}
} }
func TestHandlerChain_CaseSensitivity(t *testing.T) { func TestHandlerChain_CaseSensitivity(t *testing.T) {
@@ -516,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name string name string
scenario string scenario string
addHandlers []struct { addHandlers []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -528,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive exact match", name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase", scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -542,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "case insensitive wildcard match", name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case", scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -556,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different case same domain", name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference", scenario: "second handler should replace first despite case difference",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -571,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "subdomain matching case insensitive", name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case", scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -585,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "root zone case insensitive", name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case", scenario: "root zone handler should match regardless of case",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -599,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
name: "multiple handlers different priority", name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences", scenario: "should call higher priority handler despite case differences",
addHandlers: []struct { addHandlers: []struct {
pattern string pattern domain.Domain
priority int priority int
subdomains bool subdomains bool
shouldMatch bool shouldMatch bool
@@ -616,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
// Add handlers according to test case // Add handlers according to test case
for _, h := range tt.addHandlers { for _, h := range tt.addHandlers {
@@ -659,7 +662,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
// Execute request // Execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r) chain.ServeDNS(&test.MockResponseWriter{}, r)
// Verify each handler was called exactly as expected // Verify each handler was called exactly as expected
for _, h := range tt.addHandlers { for _, h := range tt.addHandlers {
@@ -684,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario string scenario string
ops []struct { ops []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
} }
query string query domain.Domain
expectedMatch string expectedMatch domain.Domain
}{ }{
{ {
name: "more specific domain matches first", name: "more specific domain matches first",
scenario: "sub.example.com should match before example.com", scenario: "sub.example.com should match before example.com",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -711,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "sub.example.com should match before example.com", scenario: "sub.example.com should match before example.com",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -726,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "after removing most specific, should fall back to less specific", scenario: "after removing most specific, should fall back to less specific",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -743,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "less specific domain with higher priority should match first", scenario: "less specific domain with higher priority should match first",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -758,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "with equal priority, more specific domain should match", scenario: "with equal priority, more specific domain should match",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -774,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
scenario: "specific domain should match before wildcard at same priority", scenario: "specific domain should match before wildcard at same priority",
ops: []struct { ops: []struct {
action string action string
pattern string pattern domain.Domain
priority int priority int
subdomain bool subdomain bool
}{ }{
@@ -789,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain() chain := nbdns.NewHandlerChain()
handlers := make(map[string]*nbdns.MockSubdomainHandler) handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
for _, op := range tt.ops { for _, op := range tt.ops {
if op.action == "add" { if op.action == "add" {
@@ -802,8 +805,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
} }
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations // Setup handler expectations
for pattern, handler := range handlers { for pattern, handler := range handlers {
@@ -830,3 +833,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
}) })
} }
} }
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct {
name string
addPattern domain.Domain
removePattern domain.Domain
queryPattern domain.Domain
shouldBeRemoved bool
description string
}{
{
name: "exact same pattern",
addPattern: "example.com.",
removePattern: "example.com.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical patterns",
},
{
name: "case difference",
addPattern: "Example.Com.",
removePattern: "EXAMPLE.COM.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding with mixed case, removing with uppercase",
},
{
name: "reversed case difference",
addPattern: "EXAMPLE.ORG.",
removePattern: "example.org.",
queryPattern: "example.org.",
shouldBeRemoved: true,
description: "Adding with uppercase, removing with lowercase",
},
{
name: "add wildcard, remove wildcard",
addPattern: "*.example.com.",
removePattern: "*.example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical wildcard patterns",
},
{
name: "add wildcard, remove transformed pattern",
addPattern: "*.example.net.",
removePattern: "example.net.",
queryPattern: "sub.example.net.",
shouldBeRemoved: false,
description: "Adding with wildcard, removing with non-wildcard pattern",
},
{
name: "add transformed pattern, remove wildcard",
addPattern: "example.io.",
removePattern: "*.example.io.",
queryPattern: "example.io.",
shouldBeRemoved: false,
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
},
{
name: "trailing dot difference",
addPattern: "example.dev",
removePattern: "example.dev.",
queryPattern: "example.dev.",
shouldBeRemoved: true,
description: "Adding without trailing dot, removing with trailing dot",
},
{
name: "reversed trailing dot difference",
addPattern: "example.app.",
removePattern: "example.app",
queryPattern: "example.app.",
shouldBeRemoved: true,
description: "Adding with trailing dot, removing without trailing dot",
},
{
name: "mixed case and wildcard",
addPattern: "*.Example.Site.",
removePattern: "*.EXAMPLE.SITE.",
queryPattern: "sub.example.site.",
shouldBeRemoved: true,
description: "Adding mixed case wildcard, removing uppercase wildcard",
},
{
name: "root zone",
addPattern: ".",
removePattern: ".",
queryPattern: "random.domain.",
shouldBeRemoved: true,
description: "Adding and removing root zone",
},
{
name: "wrong domain",
addPattern: "example.com.",
removePattern: "different.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding one domain, trying to remove a different domain",
},
{
name: "subdomain mismatch",
addPattern: "sub.example.com.",
removePattern: "example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: false,
description: "Adding subdomain, trying to remove parent domain",
},
{
name: "parent domain mismatch",
addPattern: "example.com.",
removePattern: "sub.example.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding parent domain, trying to remove subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)
handler.AssertNotCalled(t, "ServeDNS")
// Add handler
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
// Verify handler is called after adding
handler.On("ServeDNS", mock.Anything, r).Once()
chain.ServeDNS(w, r)
handler.AssertExpectations(t)
// Reset mock for the next test
handler.ExpectedCalls = nil
// Remove handler
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
// Set up expectations based on whether removal should succeed
if !tt.shouldBeRemoved {
handler.On("ServeDNS", mock.Anything, r).Once()
}
// Test if handler is still called after removal attempt
chain.ServeDNS(w, r)
if tt.shouldBeRemoved {
handler.AssertNotCalled(t, "ServeDNS",
"Handler should not be called after successful removal with pattern %q",
tt.removePattern)
} else {
handler.AssertExpectations(t)
handler.ExpectedCalls = nil
}
})
}
}

View File

@@ -5,15 +5,18 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
const ( const (
ipv4ReverseZone = ".in-addr.arpa" ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa" ipv6ReverseZone = ".ip6.arpa."
) )
type hostManager interface { type hostManager interface {
@@ -38,7 +41,7 @@ type HostDNSConfig struct {
type DomainConfig struct { type DomainConfig struct {
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
Domain string `json:"domain"` Domain domain.Domain `json:"domain"`
MatchOnly bool `json:"matchOnly"` MatchOnly bool `json:"matchOnly"`
} }
@@ -101,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
config.RouteAll = true config.RouteAll = true
} }
for _, domain := range nsConfig.Domains { for _, d := range nsConfig.Domains {
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(domain, "."), Domain: domain.Domain(d),
MatchOnly: !nsConfig.SearchDomainsEnabled, MatchOnly: !nsConfig.SearchDomainsEnabled,
}) })
} }
} }
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) d := strings.ToLower(dns.Fqdn(customZone.Domain))
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."), Domain: domain.Domain(d),
MatchOnly: matchOnly, MatchOnly: matchOnly,
}) })
} }

View File

@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain) matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

Some files were not shown because too many files have changed in this diff Show More