Compare commits

...

143 Commits

Author SHA1 Message Date
Pascal Fischer
39822808b3 add log message 2025-05-20 18:20:41 +02:00
Pascal Fischer
694a54d196 make interval configurable 2025-05-20 18:13:29 +02:00
Pascal Fischer
118ca450a6 go mod tidy 2025-05-20 17:38:13 +02:00
Pascal Fischer
a749e4fe73 add peer sync limiter 2025-05-20 17:32:22 +02:00
Viktor Liu
1d4cfb83e7 [client] Fix UI new version notifier (#3845) 2025-05-20 10:39:17 +02:00
Pascal Fischer
207fa059d2 [management] make locking strength clause optional (#3844) 2025-05-19 16:42:47 +02:00
Viktor Liu
cbcdad7814 [misc] Update issue template (#3842) 2025-05-19 15:41:24 +02:00
Pascal Fischer
701c13807a [management] add flag to disable auto-migration (#3840) 2025-05-19 13:36:24 +02:00
Viktor Liu
99f8dc7748 [client] Offer to remove netbird data in windows uninstall (#3766) 2025-05-16 17:39:30 +02:00
Pascal Fischer
f1de8e6eb0 [management] Make startup period configurable (#3767) 2025-05-16 13:16:51 +02:00
Viktor Liu
b2a10780af [client] Disable dnssec for systemd explicitly (#3831) 2025-05-16 09:43:13 +02:00
Pascal Fischer
43ae79d848 [management] extend rest client lib (#3830) 2025-05-15 18:20:29 +02:00
Pascal Fischer
e520b64c6d [signal] remove stream receive server side (#3820) 2025-05-14 19:28:51 +02:00
hakansa
92c91bbdd8 [client] Add FreeBSD desktop client support to OAuth flow (#3822)
[client] Add FreeBSD desktop client support to OAuth flow
2025-05-14 19:52:02 +03:00
Vlad
adf494e1ac [management] fix a bug with missed extra dns labels for a new peer (#3798) 2025-05-14 17:50:21 +02:00
Vlad
2158461121 [management,client] PKCE add flag parameter prompt=login or max_age (#3824) 2025-05-14 17:48:51 +02:00
Bethuel Mmbaga
0cd4b601c3 [management] Add connection type filter to Network Traffic API (#3815) 2025-05-14 11:15:50 +03:00
Zoltan Papp
ee1cec47b3 [client, android] Do not propagate empty routes (#3805)
If we get domain routes the Network prefix variable in route structure will be invalid (engine.go:1057). When we handower to Android the routes, we must to filter out the domain routes. If we do not do it the Android code will get "invalid prefix" string as a route.
2025-05-13 15:21:06 +02:00
Pascal Fischer
efb0edfc4c [signal] adjust signal log levels 2 (#3817) 2025-05-12 23:52:29 +02:00
Pascal Fischer
20f59ddecb [signal] adjust log levels (#3813) 2025-05-12 19:48:47 +02:00
hakansa
2f34e984b0 [client] Add TCP support to DNS forwarder service listener (#3790)
[client] Add TCP support to DNS forwarder service listener
2025-05-09 15:06:34 +03:00
Viktor Liu
d5b52e86b6 [client] Ignore irrelevant route changes to tracked network monitor routes (#3796) 2025-05-09 14:01:21 +02:00
Zoltan Papp
cad2fe1f39 Return with the correct copied length (#3804) 2025-05-09 13:56:27 +02:00
Pascal Fischer
fcd2c15a37 [management] policy delete cleans policy rules (#3788) 2025-05-07 07:25:25 +02:00
Bethuel Mmbaga
ebda0fc538 [management] Delete service users with account manager (#3793) 2025-05-06 17:31:03 +02:00
M. Essam
ac135ab11d [management/client/rest] fix panic when body is nil (#3714)
Fixes panic occurring when body is nil (this usually happens when connections is refused) due to lack of nil check by centralizing response.Body.Close() behavior.
2025-05-05 18:54:47 +02:00
Pascal Fischer
25faf9283d [management] removal of foreign key constraint enforcement on sqlite (#3786) 2025-05-05 18:21:48 +02:00
hakansa
59faaa99f6 [client] Improve NetBird installation script to handle daemon connection timeout (#3761)
[client] Improve NetBird installation script to handle daemon connection timeout
2025-05-05 17:05:01 +03:00
Viktor Liu
9762b39f29 [client] Fix stale local records (#3776) 2025-05-05 14:29:05 +02:00
Alin Trăistaru
ffdd115ded [client] set TLS ServerName for hostname-based QUIC connections (#3673)
* fix: set TLS ServerName for hostname-based QUIC connections

When connecting to a relay server by hostname, certificates are
validated against the IP address instead of the hostname.
This change sets ServerName in the TLS config when connecting
via hostname, ensuring proper certificate validation.

* use default port if port is missing in URL string
2025-05-05 12:20:54 +02:00
Pascal Fischer
055df9854c [management] add gorm tag for primary key for the networks objects (#3758) 2025-05-04 20:58:04 +02:00
Maycon Santos
12f883badf [management] Optimize load account (#3774) 2025-05-02 00:59:41 +02:00
Maycon Santos
2abb92b0d4 [management] Get account id with order (#3773)
updated log to display account id
2025-05-02 00:25:46 +02:00
Viktor Liu
01c3719c5d [client] Add debug for duration option to netbird ui (#3772) 2025-05-01 23:25:27 +02:00
Pedro Maia Costa
7b64953eed [management] user info with role permissions (#3728) 2025-05-01 11:24:55 +01:00
Viktor Liu
9bc7d788f0 [client] Add debug upload option to netbird ui (#3768) 2025-05-01 00:48:31 +02:00
Pedro Maia Costa
b5419ef11a [management] limit peers based on module read permission (#3757) 2025-04-30 15:53:18 +01:00
Zoltan Papp
d5081cef90 [client] Revert mgm client error handling (#3764) 2025-04-30 13:09:00 +02:00
Bethuel Mmbaga
488e619ec7 [management] Add network traffic events pagination (#3580)
* Add network traffic events pagination schema
2025-04-30 11:51:40 +03:00
hakansa
d2b42c8f68 [client] Add macOS .pkg installer support to installation script (#3755)
[client] Add macOS .pkg installer support to installation script
2025-04-29 13:43:42 +03:00
Maycon Santos
2f44fe2e23 [client] Feature/upload bundle (#3734)
Add an upload bundle option with the flag --upload-bundle; by default, the upload will use a NetBird address, which can be replaced using the flag --upload-bundle-url.

The upload server is available under the /upload-server path. The release change will push a docker image to netbirdio/upload image repository.

The server supports using s3 with pre-signed URL for direct upload and local file for storing bundles.
2025-04-29 00:43:50 +02:00
Bethuel Mmbaga
d8dc107bee [management] Skip IdP cache warm-up on Redis if data exists (#3733)
* Add Redis cache check to skip warm-up on startup if cache is already populated
* Refactor Redis test container setup for reusability
2025-04-28 15:10:40 +03:00
Viktor Liu
3fa915e271 [misc] Exclude client benchmarks from CI (#3752) 2025-04-28 13:40:36 +02:00
Pedro Maia Costa
47c3afe561 [management] add missing network admin mapping (#3751) 2025-04-28 11:05:27 +01:00
hakansa
84bfecdd37 [client] add byte counters & ruleID for routed traffic on userspace (#3653)
* [client] add byte counters for routed traffic on userspace 
* [client] add allowed ruleID for routed traffic on userspace
2025-04-28 10:10:41 +03:00
Viktor Liu
3cf87b6846 [client] Run container tests more generically (#3737) 2025-04-25 18:50:44 +02:00
Maycon Santos
4fe4c2054d [client] Move static check when running on foreground (#3742) 2025-04-25 18:25:48 +02:00
Pascal Fischer
38ada44a0e [management] allow impersonation via pats (#3739) 2025-04-25 16:40:54 +02:00
Pedro Maia Costa
dbf81a145e [management] network admin role (#3720) 2025-04-25 15:14:32 +01:00
Pedro Maia Costa
39483f8ca8 [management] Auditor role (#3721) 2025-04-25 15:04:25 +01:00
Carlos Hernandez
c0eaea938e [client] Fix macos privacy warning when checking static info (#3496)
avoid checking static info with a init call
2025-04-25 14:41:57 +02:00
Viktor Liu
ef8b8a2891 [client] Ensure dst-type local marks can overwrite nat marks (#3738) 2025-04-25 12:43:20 +02:00
Zoltan Papp
2817f62c13 [client] Fix error handling case of flow grpc error (#3727)
When a gRPC error occurs in the Flow package, it will be propagated to the upper layers and handled similarly to a Management gRPC error.

Always report a disconnected state in the event of any error
Hide the underlying gRPC errors
Force close the gRPC connection in the event of any error
2025-04-25 09:26:18 +02:00
Viktor Liu
4a9049566a [client] Set up firewall rules for dns routes dynamically based on dns response (#3702) 2025-04-24 17:37:28 +02:00
Viktor Liu
85f92f8321 [client] Add more userspace filter ACL test cases (#3730) 2025-04-24 12:57:46 +02:00
Viktor Liu
714beb6e3b [client] Fix exit node deselection (#3722) 2025-04-24 12:36:05 +02:00
Viktor Liu
400b9fca32 [management] Add firewall rule route ID and missing route domains (#3700) 2025-04-23 21:29:46 +02:00
hakansa
4013298e22 [client/ui] add connecting state to status handling (#3712) 2025-04-23 21:04:38 +02:00
Pascal Fischer
312bfd9bd7 [management] support custom domains per account (#3726) 2025-04-23 19:36:53 +02:00
Pascal Fischer
8db05838ca [misc] Change github runner for docker test (#3707) 2025-04-23 19:35:26 +02:00
Misha Bragin
c69df13515 [management] Add account meta (#3724) 2025-04-23 18:44:22 +02:00
Pascal Fischer
986eb8c1e0 [management] fix lastLogin on dashboard (#3725) 2025-04-23 15:54:49 +02:00
dependabot[bot]
197761ba4d Bump github.com/redis/go-redis/v9 from 9.7.1 to 9.7.3 (#3553)
Bumps [github.com/redis/go-redis/v9](https://github.com/redis/go-redis) from 9.7.1 to 9.7.3.
- [Release notes](https://github.com/redis/go-redis/releases)
- [Changelog](https://github.com/redis/go-redis/blob/master/CHANGELOG.md)
- [Commits](https://github.com/redis/go-redis/compare/v9.7.1...v9.7.3)

---
updated-dependencies:
- dependency-name: github.com/redis/go-redis/v9
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-23 10:21:36 +02:00
dependabot[bot]
f74ea64c7b Bump golang.org/x/net from 0.36.0 to 0.38.0 (#3695)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.36.0 to 0.38.0.
- [Commits](https://github.com/golang/net/compare/v0.36.0...v0.38.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-version: 0.38.0
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-04-23 10:20:51 +02:00
Viktor Liu
3b7b9d25bc [client] Keep new routes selected unless all are deselected (#3692) 2025-04-23 01:07:04 +02:00
Pascal Fischer
1a6d6b3109 [management] fix github run id (#3705) 2025-04-18 11:21:54 +02:00
Pascal Fischer
f686615876 [management] benchmarks use ref_name instead (#3704) 2025-04-17 21:57:54 +02:00
Pascal Fischer
a4311f574d [management] push benchmark results to grafana (#3701) 2025-04-17 21:01:23 +02:00
Pierre Timmermans
0bb8eae903 [docs] fix: broken link in the README file (#3697)
improve README.md, broken link for activity logging
2025-04-17 14:48:10 +02:00
Pascal Fischer
e0b33d325d [management] permissions manager use crud operations (#3690) 2025-04-16 17:25:03 +02:00
Zoltan Papp
c38e07d89a [client] Fix Rosenpass permissive mode handling (#3689)
fixes the Rosenpass preshared key handling to enable successful WireGuard handshakes when one side is in permissive mode. Key changes include:

Updating field accesses from RosenpassPubKey/RosenpassAddr to RosenpassConfig.PubKey/RosenpassConfig.Addr.
Modifying the preshared key computation logic to account for permissive mode.
Revising peer configuration in the Engine to use the new RosenpassConfig struct.
2025-04-16 16:04:43 +02:00
Lamera
a37368fff4 [misc] update gpt file permissions in install.sh (#3663)
* Fix install.sh for some installations

Fix install.sh for some installations by explicitly setting the file permissions

* Add sudo
2025-04-16 14:23:25 +02:00
Viktor Liu
0c93bd3d06 [client] Keep selecting new networks after first deselection (#3671) 2025-04-16 13:55:26 +02:00
Viktor Liu
a675531b5c [client] Set up signal to generate debug bundles (#3683) 2025-04-16 11:06:22 +02:00
hakansa
7cb366bc7d [client] Remove logrus writer assignment in pion logging (#3684) 2025-04-15 18:15:52 +03:00
Viktor Liu
a354004564 [client] Add remaining debug profiles (#3681) 2025-04-15 13:06:28 +02:00
Pedro Maia Costa
75bdd47dfb [management] get current user endpoint (#3666) 2025-04-15 11:06:07 +01:00
Viktor Liu
b165f63327 [client] Add heap profile to debug bundle (#3679) 2025-04-15 11:36:41 +02:00
hakansa
51bb52cdf5 [client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
[client] Refactor DNSForwarder to improve handle wildcard domain resource id matching (#3651)
2025-04-15 10:54:17 +03:00
Pedro Maia Costa
4134b857b4 [management] add permissions manager to geolocation handler (#3665) 2025-04-14 17:57:58 +01:00
Vlad
7839d2c169 [management] Refactor/management/updchannel (#3645)
* refactoring updatechannel - use read mutex for send update
2025-04-11 18:22:59 +03:00
Pascal Fischer
b9f82e2f8a [management] Buffer updateAccountPeers calls (#3644) 2025-04-11 17:21:05 +02:00
Pedro Maia Costa
fd2a21c65d [management] remove unnecessary access control middleware (#3650) 2025-04-11 10:43:59 +01:00
Maycon Santos
82d982b0ab [management,client] Add support to configurable prompt login (#3660) 2025-04-11 11:34:55 +02:00
Maycon Santos
9e24fe7701 [docs] Fix a few typos on table (#3658) 2025-04-10 17:57:39 +02:00
Pedro Maia Costa
e470701b80 [ci] include stash in pr template (#3657) 2025-04-10 16:30:44 +01:00
Viktor Liu
e3ce026355 [client] Fix race dns cleanup race condition (#3652) 2025-04-10 13:21:14 +02:00
Pascal Fischer
5ea2806663 [management] use permission modules (#3622) 2025-04-10 11:06:52 +02:00
Viktor Liu
d6b0673580 [client] Support CNAME in local resolver (#3646) 2025-04-10 10:38:47 +02:00
Pedro Maia Costa
14913cfa7a add git town config (#3555) 2025-04-09 20:18:52 +01:00
Viktor Liu
03f600b576 [client] Fallback to TCP if a truncated UDP response is received from upstream DNS (#3632) 2025-04-08 13:41:13 +02:00
Viktor Liu
192c97aa63 [client] Support IP fragmentation in userspace (#3639) 2025-04-08 12:49:14 +02:00
Maycon Santos
4db78db49a [misc] Update FreeBSD workflow (#3638)
Update FreeBSD release to 14.2 and download Go package directly since port wasn't finding the package to install
2025-04-08 09:15:09 +02:00
Viktor Liu
87e600a4f3 [client] Automatically register match domains for DNS routes (#3614) 2025-04-07 15:18:45 +02:00
Viktor Liu
6162aeb82d [client] Mark netbird data plane traffic to identify interface traffic correctly (#3623) 2025-04-07 13:14:56 +02:00
hakansa
1ba1e092ce [client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
[client] Enhance DNS forwarder to track resolved IPs with resource IDs on routing peers (#3620)
2025-04-07 15:16:12 +08:00
hakansa
86dbb4ee4f [client] Add no-browser flag to login and up commands for SSO login control (#3610)
* [client] Add no-browser flag to login and up commands for SSO login control (#3610)
2025-04-07 14:39:53 +08:00
hakansa
4af177215f [client] Fix Status Recorder Route Removal Logic to Handle Dynamic Routes Correctly 2025-04-06 09:57:28 +08:00
Viktor Liu
df9c1b9883 [client] Improve TCP conn tracking (#3572) 2025-04-05 11:42:15 +02:00
Viktor Liu
5752bb78f2 [client] Fix missing inbound flows in Linux userspace mode with native router (#3624)
* Fix missing inbound flows in Linux userspace mode with native router

* Fix route enable/disable order for userspace mode
2025-04-05 11:41:31 +02:00
Maycon Santos
fbd783ad58 [client] Use the netbird logger for ice and grpc (#3603)
updates the logging implementation to use the netbird logger for both ICE and gRPC components. The key changes include:

- Introducing a gRPC logger configuration in util/log.go that integrates with the netbird logging setup.
- Updating the log hook in formatter/hook/hook.go to ensure a default caller is used when not set.
- Refactoring ICE agent and UDP multiplexers to use a unified logger via the new getLogger() method.
2025-04-04 18:30:47 +02:00
Viktor Liu
80702b9323 [client] Fix dns forwarder handling of requested record types (#3615) 2025-04-03 13:58:36 +02:00
Viktor Liu
09243a0fe0 [management] Remove remaining backend linux router limitation (#3589) 2025-04-01 21:29:57 +02:00
Maycon Santos
3658215747 [client] Force new user login on PKCE auth in CLI (#3604)
With this change, browser session won't be considered for cli authentication and credentials will be requested
2025-04-01 10:29:29 +02:00
Viktor Liu
48ffec95dd Improve local ip lookup (#3551)
- lower memory footprint in most cases
- increase accuracy
2025-03-31 10:05:57 +02:00
Pedro Maia Costa
cbec7bda80 [management] permission manager validate account access (#3444) 2025-03-30 17:08:22 +02:00
Zoltan Papp
21464ac770 [client] Fix close WireGuard watcher (#3598)
This PR fixes issues with closing the WireGuard watcher by adjusting its asynchronous invocation and synchronization.

Update tests in wg_watcher_test.go to launch the watcher in a goroutine and add a delay for timing.
Modify wg_watcher.go to run the periodic handshake check synchronously by removing the waitGroup and goroutine.
Enhance conn.go to wait on the watcher wait group during connection close and add a note for potential further synchronization
2025-03-28 20:12:31 +01:00
Zoltan Papp
ed5647028a [client] Prevent calling the onDisconnected callback in incorrect state (#3582)
Prevent calling the onDisconnected callback if the ICE connection has never been established

If call onDisconnected without onConnected then overwrite the relayed status in the conn priority variable.
2025-03-28 18:08:26 +01:00
Viktor Liu
29a6e5be71 [client] Stop flow grpc receiver properly (#3596) 2025-03-28 16:08:31 +01:00
Viktor Liu
6124e3b937 [client] Disable systemd-resolved default route explicitly on match domains only (#3584) 2025-03-28 11:14:32 +01:00
Maycon Santos
50f5cc48cd [management] Fix extended config when nil (#3593)
* Fix extended config when nil

* update integrations
2025-03-27 23:07:10 +01:00
Viktor Liu
101cce27f2 [client] Ensure status recorder is always initialized (#3588)
* Ensure status recorder is always initialized

* Add test

* Add subscribe test
2025-03-27 22:48:11 +01:00
Maycon Santos
a4f04f5570 [management] fix extend call and move config to types (#3575)
This PR fixes configuration inconsistencies and updates the store engine type usage throughout the management code. Key changes include:
- Replacing outdated server.Config references with types.Config and updating related flag variables (e.g. types.MgmtConfigPath).
- Converting engine constants (SqliteStoreEngine, PostgresStoreEngine, MysqlStoreEngine) to use types.Engine for consistent type–safety.
- Adjusting various test and migration code paths to correctly reference the new configuration and engine types.
2025-03-27 13:04:50 +01:00
hakansa
fceb3ca392 [client] fix route handling for local peer state (#3586) 2025-03-27 19:31:04 +08:00
Bethuel Mmbaga
34d86c5ab8 [management] Sync account peers on network router group changes (#3573)
- Updates account peers when a group linked to a network router is modified
- Prevents group deletion if it's still being used by any network router
2025-03-27 12:19:22 +01:00
Maycon Santos
9cbcf7531f [management] Fix invalid port range sync (#3571)
We should not send port range when a port is set or when protocol is all or icmp
2025-03-24 00:56:51 +01:00
Maycon Santos
bd8f0c1ef3 [client] add profiling dumps to debug package (#3517)
enhances debugging capabilities by adding support for goroutine, mutex, and block profiling while updating state dump tracking and refining test and release settings.

- Adds pprof-based profiling for goroutine, mutex, and block profiles in the debug bundle.
- Updates state dump functionality by incorporating new status and key fields.
- Adjusts test validations and default flag/retention settings.
2025-03-23 13:46:09 +01:00
Renat Galiev
051a5a4adc [misc] chore: remove duplicate labels for services.relay in docker-compose.yml.tmpl.traefik (#3502)
Signed-off-by: Renat Galiev <renat@galiev.net>
2025-03-22 23:14:42 +01:00
Maycon Santos
8b4c0c58e4 [client] Add initiator field to ack (#3563)
added the new field and client handling
2025-03-22 22:22:34 +01:00
Viktor Liu
99b41543b8 [client] Fix flows for embedded listeners (#3564) 2025-03-22 18:51:48 +01:00
Viktor Liu
2bbe0f3f09 [client] Don't permanently fail on flow grpc shutdown (#3557) 2025-03-22 11:56:00 +01:00
Misha Bragin
9325fb7990 Remove UI client Admin Panel item (#3560) 2025-03-21 18:48:15 +01:00
Pascal Fischer
f081435a56 [management] add log when using redis cache (#3562) 2025-03-21 18:16:27 +01:00
Pascal Fischer
b62a1b56ce [docs] rename network traffic logging to traffic events (#3556) 2025-03-21 16:32:47 +01:00
Pascal Fischer
8d7c92c661 [management] add receive timestamp to traffic event (#3559) 2025-03-21 16:31:23 +01:00
Maycon Santos
d9d051cb1e Add initiator field and parse url (#3558)
- Add initiator field to flow proto
- Parse URL
- Update a few trace logs
2025-03-21 14:47:04 +01:00
Maycon Santos
cb318b7ef4 [client] Use UTC on event generation (#3554) 2025-03-21 11:14:51 +01:00
Pascal Fischer
8f0aa8352a [docs] add examples to events and tag to ingress port (#3552) 2025-03-20 18:26:08 +01:00
Maycon Santos
c02e236196 [client,management] add netflow support to client and update management (#3414)
adds NetFlow functionality to track and log network traffic information between peers, with features including:

- Flow logging for TCP, UDP, and ICMP traffic
- Integration with connection tracking system
- Resource ID tracking in NetFlow events
- DNS and exit node collection configuration
- Flow API and Redis cache in management
- Memory-based flow storage implementation
- Kernel conntrack counters and userspace counters
- TCP state machine improvements for more accurate tracking
- Migration from net.IP to netip.Addr in the userspace firewall
2025-03-20 17:05:48 +01:00
Dominik
f51e0b59bd [management] Posture checks handle suffixes like "-dev" in netbird version (#3511) 2025-03-20 16:28:39 +01:00
Misha Bragin
32ec42a667 Update CONTRIBUTOR_LICENSE_AGREEMENT.md (#3535) 2025-03-19 15:11:58 +01:00
Alexandre JARDON
9929daf6ce [client] Fix DNS Nrpt policies (#3459) 2025-03-18 22:57:41 +01:00
M. Essam
939419a0ea [management] Add Bearer token support (#3534) 2025-03-18 21:48:36 +01:00
Christian Alexander Sauer Mark
919fe94fd5 Fix always enabling of NetworkResource in createResource() (#3532) 2025-03-18 19:41:15 +01:00
dependabot[bot]
df71cb4690 [client,management] Bump golang.org/x/net from 0.33.0 to 0.36.0 (#3492)
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.33.0 to 0.36.0.
- [Commits](https://github.com/golang/net/compare/v0.33.0...v0.36.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-16 17:55:22 +01:00
levindecaro
4508c61728 [client] Fix Advanced Setting unable to open on Windows 11 with Chinese Locale Setting. (#3483)
Fix #3345 and #2603
2025-03-16 17:51:42 +01:00
Viktor Liu
0ef476b014 [client] Fix state dump panic (#3519) 2025-03-16 15:13:04 +01:00
Zoltan Papp
6f82e96d6a [client] Set info logs (#3504)
collect and log connection stats per peer every 10 minutes
2025-03-14 22:34:41 +01:00
Viktor Liu
a2faae5d62 [client] Fix anonymized addresses documentation (#3505) 2025-03-14 11:38:16 +01:00
Zoltan Papp
4a3cbcd38a Nil check on route manager (#3486) 2025-03-13 00:04:00 +01:00
Misha Bragin
c2980bc8cf Update link to kubernetes operator (#3489) 2025-03-12 21:18:19 +01:00
Pascal Fischer
67ae871ce4 [management] return empty array instead of null on networks endpoints (#3480) 2025-03-11 00:20:54 +01:00
Maycon Santos
39ff5e833a [misc] Update slack invite link (#3479) 2025-03-11 00:12:11 +01:00
365 changed files with 23489 additions and 8263 deletions

27
.git-branches.toml Normal file
View File

@@ -0,0 +1,27 @@
# More info around this file at https://www.git-town.com/configuration-file
[branches]
main = "main"
perennials = []
perennial-regex = ""
[create]
new-branch-type = "feature"
push-new-branches = false
[hosting]
dev-remote = "origin"
# platform = ""
# origin-hostname = ""
[ship]
delete-tracking-branch = false
strategy = "squash-merge"
[sync]
feature-strategy = "merge"
perennial-strategy = "rebase"
prototype-strategy = "merge"
push-hook = true
tags = true
upstream = false

View File

@@ -37,17 +37,22 @@ If yes, which one?
**Debug output** **Debug output**
To help us resolve the problem, please attach the following debug output To help us resolve the problem, please attach the following anonymized status output
netbird status -dA netbird status -dA
As well as the file created by Create and upload a debug bundle, and share the returned file key:
netbird debug for 1m -AS -U
*Uploaded files are automatically deleted after 30 days.*
Alternatively, create the file only and attach it here manually:
netbird debug for 1m -AS netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
If applicable, add screenshots to help explain your problem. If applicable, add screenshots to help explain your problem.
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?** **Have you tried these troubleshooting steps?**
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
- [ ] Checked for newer NetBird versions - [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones) - [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client - [ ] Restarted the NetBird client
- [ ] Disabled other VPN software - [ ] Disabled other VPN software
- [ ] Checked firewall settings - [ ] Checked firewall settings

View File

@@ -2,6 +2,10 @@
## Issue ticket number and link ## Issue ticket number and link
## Stack
<!-- branch-stack -->
### Checklist ### Checklist
- [ ] Is it a bug fix - [ ] Is it a bug fix
- [ ] Is a typo/documentation fix - [ ] Is a typo/documentation fix

21
.github/workflows/git-town.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
name: Git Town
on:
pull_request:
branches:
- '**'
jobs:
git-town:
name: Display the branch stack
runs-on: ubuntu-latest
permissions:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: git-town/action@v1
with:
skip-single-stacks: true

View File

@@ -22,14 +22,20 @@ jobs:
with: with:
usesh: true usesh: true
copyback: false copyback: false
release: "14.1" release: "14.2"
prepare: | prepare: |
pkg install -y go pkgconf xorg pkg install -y curl pkgconf xorg
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
GO_URL="https://go.dev/dl/$GO_TARBALL"
curl -vLO "$GO_URL"
tar -C /usr/local -vxzf "$GO_TARBALL"
# -x - to print all executed commands # -x - to print all executed commands
# -e - to faile on first error # -e - to faile on first error
run: | run: |
set -e -x set -e -x
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd # check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/... time go test -timeout 1m -failfast ./base62/...

View File

@@ -146,6 +146,65 @@ jobs:
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay) run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [build-cache]
runs-on: ubuntu-22.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
id: go-env
run: |
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@v4
id: cache-restore
with:
path: |
${{ steps.go-env.outputs.cache_dir }}
${{ steps.go-env.outputs.modcache_dir }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Run tests in container
env:
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
run: |
CONTAINER_GOCACHE="/root/.cache/go-build"
CONTAINER_GOMODCACHE="/go/pkg/mod"
docker run --rm \
--cap-add=NET_ADMIN \
--privileged \
-v $PWD:/app \
-w /app \
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
-e CGO_ENABLED=1 \
-e CI=true \
-e DOCKER_CI=true \
-e GOARCH=${GOARCH_TARGET} \
-e GOCACHE=${CONTAINER_GOCACHE} \
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
golang:1.23-alpine \
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
'
test_relay: test_relay:
name: "Relay / Unit" name: "Relay / Unit"
needs: [build-cache] needs: [build-cache]
@@ -179,13 +238,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -232,13 +284,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -286,13 +331,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -314,6 +352,7 @@ jobs:
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=devcert \ go test -tags=devcert \
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \ -exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
-timeout 20m ./management/... -timeout 20m ./management/...
@@ -353,13 +392,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -380,10 +412,11 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags devcert -run=^$ -bench=. \ go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./... -timeout 20m ./management/...
api_benchmark: api_benchmark:
name: "Management / Benchmark (API)" name: "Management / Benchmark (API)"
@@ -396,6 +429,33 @@ jobs:
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Create Docker network
run: docker network create promnet
- name: Start Prometheus Pushgateway
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
- name: Start Prometheus (for Pushgateway forwarding)
run: |
echo '
global:
scrape_interval: 15s
scrape_configs:
- job_name: "pushgateway"
static_configs:
- targets: ["pushgateway:9091"]
remote_write:
- url: ${{ secrets.GRAFANA_URL }}
basic_auth:
username: ${{ secrets.GRAFANA_USER }}
password: ${{ secrets.GRAFANA_API_KEY }}
' > prometheus.yml
docker run -d --name prometheus --network promnet \
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
-p 9090:9090 \
prom/prometheus
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
@@ -420,13 +480,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -447,11 +500,13 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \ go test -tags=benchmark \
-run=^$ \ -run=^$ \
-bench=. \ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... -timeout 20m ./management/...
api_integration_test: api_integration_test:
@@ -489,13 +544,6 @@ jobs:
restore-keys: | restore-keys: |
${{ runner.os }}-gotest-cache- ${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install 32-bit libpcap
if: matrix.arch == '386'
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy
@@ -505,89 +553,8 @@ jobs:
- name: Test - name: Test
run: | run: |
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \ CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \ NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
go test -tags=integration \ go test -tags=integration \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \ -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
-timeout 20m ./management/... -timeout 20m ./management/...
test_client_on_docker:
name: "Client (Docker) / Unit"
needs: [ build-cache ]
runs-on: ubuntu-20.04
steps:
- name: Install Go
uses: actions/setup-go@v5
with:
go-version: "1.23.x"
cache: false
- name: Checkout code
uses: actions/checkout@v4
- name: Get Go environment
run: |
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@v4
with:
path: |
${{ env.cache }}
${{ env.modcache }}
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ runner.os }}-gotest-cache-
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
- name: Install modules
run: go mod tidy
- name: check git status
run: git --no-pager diff --exit-code
- name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
- name: Generate RouteManager Test bin
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
- name: Generate SystemOps Test bin
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
- name: Generate nftables Manager Test bin
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
- name: Generate Engine Test bin
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin
- name: Run Shared Sock tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run SystemOps tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
- name: Run nftables Manager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with file store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Engine tests in docker with sqlite store
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Peer tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@@ -87,25 +87,25 @@ jobs:
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 7
- name: upload linux packages - name: upload linux packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 3 retention-days: 7
- name: upload windows packages - name: upload windows packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 3 retention-days: 7
- name: upload macos packages - name: upload macos packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
path: dist/netbird_darwin** path: dist/netbird_darwin**
retention-days: 3 retention-days: 7
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@@ -178,6 +178,8 @@ jobs:
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$' grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445" grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
grep -A 7 Relay management.json | egrep '"Secret": ".+"' grep -A 7 Relay management.json | egrep '"Secret": ".+"'
grep DisablePromptLogin management.json | grep 'true'
grep LoginFlag management.json | grep 0
- name: Install modules - name: Install modules
run: go mod tidy run: go mod tidy

View File

@@ -96,6 +96,20 @@ builds:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}" mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-upload
dir: upload-server
env: [CGO_ENABLED=0]
binary: netbird-upload
goos:
- linux
goarch:
- amd64
- arm64
- arm
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries: universal_binaries:
- id: netbird - id: netbird
@@ -409,6 +423,52 @@ dockers:
- "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}" - "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io" - "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests: docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }} - name_template: netbirdio/netbird:{{ .Version }}
image_templates: image_templates:
@@ -475,7 +535,17 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-arm64v8 - netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm - netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
brews: brews:
- ids: - ids:
- default - default

View File

@@ -1,148 +1,64 @@
# Contributor License Agreement ## Contributor License Agreement
We are incredibly thankful for the contributions we receive from the community. This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
We require our external contributors to sign a Contributor License Agreement ("CLA") in submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
order to ensure that our projects remain licensed under Free and Open Source licenses such referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
as BSD-3 while allowing NetBird to build a sustainable business. under which NetBird may utilize software contributions provided by the Contributor for inclusion in
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
NetBird is committed to having a true Open Source Software ("OSS") license for of the terms and conditions outlined below. The Contributor further represents that they are authorized to
our software. A CLA enables NetBird to safely commercialize our products complete this process as described herein.
while keeping a standard OSS license with all the rights that license grants to users: the
ability to use the project in their own projects or businesses, to republish modified
source, or to completely fork the project.
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
# Human-friendly summary
This is a human-readable summary of (and not a substitute for) the full agreement (below).
This highlights only some of key terms of the CLA. It has no legal value and you should
carefully review all the terms of the actual CLA before agreeing.
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
in commercial products.
</li>
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
license to use that patent including within commercial products. You also agree that you
have permission to grant this license.
</li>
<li>No Warranty or Support Obligations.
By making a contribution, you are not obligating yourself to provide support for the
contribution, and you are not taking on any warranty obligations or providing any
assurances about how it will perform.
</li>
The CLA does not change the terms of the standard open source license used by our software
such as BSD-3 or MIT.
You are still free to use our projects within your own projects or businesses, republish
modified source, and more.
Please reference the appropriate license for the project you're contributing to to learn
more.
# Why require a CLA?
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
products.
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
under the project's respective open source license, such as BSD-3.
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
# Signing the CLA
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
to sign the CLA if you haven't already.
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
information from your account) and to fill in a few additional details such as your name and email address. We will only
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
require you to sign again.
# Legal Terms and Agreement
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
your own Contributions for any other purpose.
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
You reserve all right, title, and interest in and to Your Contributions.
1. Definitions.
```
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
entities that control, are controlled by, or are under common control with that entity are considered
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
```
```
"Contribution" shall mean any original work of authorship, including any modifications or additions to
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
or documentation of, any of the products owned or managed by NetBird (the "Work").
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
source code control systems, and issue tracking systems that are managed by, or on behalf of,
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
marked or otherwise designated in writing by You as "Not a Contribution."
```
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
perform, sublicense, and distribute Your Contributions and such derivative works.
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and ## 1 Preamble
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, must have a contributor license agreement on file to be signed by each Contributor, containing the license
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are terms below. This license serves as protection for both the Contributor as well as NetBird and its software users;
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which it does not change Contributors rights to use his/her own Contributions for any other purpose.
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
## 2 Definitions
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to 2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work.
intellectual property that you create that includes your Contributions, you represent that you have received
permission to make Contributions on behalf of that employer, that you will have received permission from your current
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
with NetBird.
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of 2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution".
others). You represent that Your Contribution submissions include complete details of any third-party license or
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
and which are associated with any part of Your Contributions.
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support. ## 3 Licenses
You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in 3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees.
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributors Contribution(s) alone or by combination of Contributors Contribution(s) with the Work to which such Contribution(s) was Submitted.
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from 3.3 NetBird hereby accepts such licenses.
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
## 4 Contributors Representations
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributors employer has IP Rights to Contributors Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
4.2 Contributor represents that any Contribution is his/her original creation.
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
## 5 Information obligation
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
## 6 Submission of Third-Party works
Should Contributor wish to submit work that is not Contributors original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
## 7 No Consideration
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
## 8 Final Provisions
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
representations inaccurate in any respect.

View File

@@ -12,7 +12,7 @@
<img src="https://img.shields.io/badge/license-BSD--3-blue" /> <img src="https://img.shields.io/badge/license-BSD--3-blue" />
</a> </a>
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a> </a>
<br> <br>
@@ -29,13 +29,13 @@
<br/> <br/>
See <a href="https://netbird.io/docs/">Documentation</a> See <a href="https://netbird.io/docs/">Documentation</a>
<br/> <br/>
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a> Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
<br/> <br/>
</strong> </strong>
<br> <br>
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github"> <a href="https://github.com/netbirdio/kubernetes-operator">
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts New: NetBird Kubernetes Operator
</a> </a>
</p> </p>
@@ -58,15 +58,15 @@
### Key features ### Key features
| Connectivity | Management | Security | Automation| Platforms | | Connectivity | Management | Security | Automation| Platforms |
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------| |----|----|----|----|----|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> | | <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> | | <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> | | <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> | | <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> | | <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | ||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | ||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | ||||| <ul><li>- \[x] Docker</ui></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud

View File

@@ -1,5 +1,6 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
RUN apk add --no-cache ca-certificates iptables ip6tables # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird COPY netbird /usr/local/bin/netbird

View File

@@ -26,7 +26,7 @@ type Anonymizer struct {
} }
func DefaultAddresses() (netip.Addr, netip.Addr) { func DefaultAddresses() (netip.Addr, netip.Addr) {
// 192.51.100.0, 100:: // 198.51.100.0, 100::
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01}) return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
} }

View File

@@ -11,9 +11,12 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/management/proto"
) )
const errCloseConnection = "Failed to close connection: %v" const errCloseConnection = "Failed to close connection: %v"
@@ -84,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
}() }()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{ request := &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: getStatusOutput(cmd, anonymizeFlag), Status: getStatusOutput(cmd, anonymizeFlag),
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
}) }
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
cmd.Printf("Local file:\n%s\n", resp.GetPath())
cmd.Println(resp.GetPath()) if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil return nil
} }
@@ -208,23 +222,19 @@ func runForDuration(cmd *cobra.Command, args []string) error {
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration) headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag)) statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
request := &proto.DebugBundleRequest{
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
Status: statusOutput, Status: statusOutput,
SystemInfo: debugSystemInfoFlag, SystemInfo: debugSystemInfoFlag,
}) }
if debugUploadBundle {
request.UploadURL = debugUploadBundleURL
}
resp, err := client.DebugBundle(cmd.Context(), request)
if err != nil { if err != nil {
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
} }
// Disable network map persistence after creating the debug bundle
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
Enabled: false,
}); err != nil {
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
}
if stateWasDown { if stateWasDown {
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
@@ -239,7 +249,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
cmd.Println("Log level restored to", initialLogLevel.GetLevel()) cmd.Println("Log level restored to", initialLogLevel.GetLevel())
} }
cmd.Println(resp.GetPath()) cmd.Printf("Local file:\n%s\n", resp.GetPath())
if resp.GetUploadFailureReason() != "" {
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
}
if debugUploadBundle {
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
}
return nil return nil
} }
@@ -326,3 +344,34 @@ func formatDuration(d time.Duration) string {
s := d / time.Second s := d / time.Second
return fmt.Sprintf("%02d:%02d:%02d", h, m, s) return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
} }
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
var networkMap *mgmProto.NetworkMap
var err error
if connectClient != nil {
networkMap, err = connectClient.GetLatestNetworkMap()
if err != nil {
log.Warnf("Failed to get latest network map: %v", err)
}
}
bundleGenerator := debug.NewBundleGenerator(
debug.GeneratorDependencies{
InternalConfig: config,
StatusRecorder: recorder,
NetworkMap: networkMap,
LogFile: logFilePath,
},
debug.BundleConfig{
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
log.Errorf("Failed to generate debug bundle: %v", err)
return
}
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
}

39
client/cmd/debug_unix.go Normal file
View File

@@ -0,0 +1,39 @@
//go:build unix
package cmd
import (
"context"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
usr1Ch := make(chan os.Signal, 1)
signal.Notify(usr1Ch, syscall.SIGUSR1)
go func() {
for {
select {
case <-ctx.Done():
return
case <-usr1Ch:
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
go generateDebugBundle(config, recorder, connectClient, logFilePath)
}
}
}()
}

126
client/cmd/debug_windows.go Normal file
View File

@@ -0,0 +1,126 @@
package cmd
import (
"context"
"errors"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
)
const (
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
waitTimeout = 5 * time.Second
)
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
// Example usage with PowerShell:
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
// $evt.Set()
// $evt.Close()
func SetupDebugHandler(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
) {
env := os.Getenv(envListenEvent)
if env == "" {
return
}
listenEvent, err := strconv.ParseBool(env)
if err != nil {
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
return
}
if !listenEvent {
return
}
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
if err != nil {
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
return
}
// TODO: restrict access by ACL
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
if err != nil {
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
if err != nil {
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
} else {
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
return
}
}
if eventHandle == windows.InvalidHandle {
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
return
}
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
}
func waitForEvent(
ctx context.Context,
config *internal.Config,
recorder *peer.Status,
connectClient *internal.ConnectClient,
logFilePath string,
eventHandle windows.Handle,
) {
defer func() {
if err := windows.CloseHandle(eventHandle); err != nil {
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
}
}()
for {
if ctx.Err() != nil {
return
}
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
switch status {
case windows.WAIT_OBJECT_0:
log.Info("Received signal on debug event. Triggering debug bundle generation.")
// reset the event so it can be triggered again later (manual reset == 1)
if err := windows.ResetEvent(eventHandle); err != nil {
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
}
go generateDebugBundle(config, recorder, connectClient, logFilePath)
case uint32(windows.WAIT_TIMEOUT):
default:
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
select {
case <-time.After(5 * time.Second):
case <-ctx.Done():
return
}
}
}
}

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"runtime"
"strings" "strings"
"time" "time"
@@ -19,6 +20,10 @@ import (
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
}
var loginCmd = &cobra.Command{ var loginCmd = &cobra.Command{
Use: "login", Use: "login",
Short: "login to the Netbird Management Service (first run)", Short: "login to the Netbird Management Service (first run)",
@@ -51,6 +56,9 @@ var loginCmd = &cobra.Command{
return err return err
} }
// update host's static platform and system information
system.UpdateStaticInfo()
ic := internal.ConfigInput{ ic := internal.ConfigInput{
ManagementURL: managementURL, ManagementURL: managementURL,
AdminURL: adminURL, AdminURL: adminURL,
@@ -93,7 +101,7 @@ var loginCmd = &cobra.Command{
loginRequest := proto.LoginRequest{ loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey, SetupKey: providedSetupKey,
ManagementUrl: managementURL, ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
DnsLabels: dnsLabelsReq, DnsLabels: dnsLabelsReq,
} }
@@ -127,7 +135,7 @@ var loginCmd = &cobra.Command{
} }
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {
@@ -188,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop()) oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -198,7 +206,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout) waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
@@ -212,23 +220,34 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
} }
if noBrowser {
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
} else {
cmd.Println("Please do the SSO login in your browser. \n" + cmd.Println("Please do the SSO login in your browser. \n" +
"If your browser didn't open automatically, use this URL to log in:\n\n" + "If your browser didn't open automatically, use this URL to log in:\n\n" +
verificationURIComplete + " " + codeMsg) verificationURIComplete + " " + codeMsg)
}
cmd.Println("") cmd.Println("")
if !noBrowser {
if err := open.Run(verificationURIComplete); err != nil { if err := open.Run(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
}
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isLinuxRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
} }

View File

@@ -22,6 +22,7 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/upload-server/types"
) )
const ( const (
@@ -39,6 +40,8 @@ const (
dnsRouteIntervalFlag = "dns-router-interval" dnsRouteIntervalFlag = "dns-router-interval"
systemInfoFlag = "system-info" systemInfoFlag = "system-info"
blockLANAccessFlag = "block-lan-access" blockLANAccessFlag = "block-lan-access"
uploadBundle = "upload-bundle"
uploadBundleURL = "upload-bundle-url"
) )
var ( var (
@@ -75,6 +78,8 @@ var (
debugSystemInfoFlag bool debugSystemInfoFlag bool
dnsRouteInterval time.Duration dnsRouteInterval time.Duration
blockLANAccess bool blockLANAccess bool
debugUploadBundle bool
debugUploadBundleURL string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "netbird", Use: "netbird",
@@ -180,7 +185,9 @@ func init() {
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted") upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.") upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle") debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
} }
// SetupCloseHandler handles SIGTERM signal and exits with success // SetupCloseHandler handles SIGTERM signal and exits with success

View File

@@ -16,12 +16,17 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/server" "github.com/netbirdio/netbird/client/server"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
func (p *program) Start(svc service.Service) error { func (p *program) Start(svc service.Service) error {
// Start should not block. Do the actual work async. // Start should not block. Do the actual work async.
log.Info("starting Netbird service") //nolint log.Info("starting Netbird service") //nolint
// Collect static system and platform information
system.UpdateStaticInfo()
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API. // in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
p.serv = grpc.NewServer() p.serv = grpc.NewServer()
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
ctx, cancel := context.WithCancel(cmd.Context()) ctx, cancel := context.WithCancel(cmd.Context())
SetupCloseHandler(ctx, cancel) SetupCloseHandler(ctx, cancel)
SetupDebugHandler(ctx, nil, nil, nil, logFile)
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig()) s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
if err != nil { if err != nil {

View File

@@ -6,14 +6,17 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -31,7 +34,7 @@ import (
func startTestingServices(t *testing.T) string { func startTestingServices(t *testing.T) string {
t.Helper() t.Helper()
config := &mgmt.Config{} config := &types.Config{}
_, err := util.ReadJson("../testdata/management.json", config) _, err := util.ReadJson("../testdata/management.json", config)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -66,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
@@ -89,14 +92,24 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock()) settingsMockManager := settings.NewMockManager(ctrl)
permissionsManagerMock := permissions.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -32,12 +32,16 @@ const (
const ( const (
dnsLabelsFlag = "extra-dns-labels" dnsLabelsFlag = "extra-dns-labels"
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
) )
var ( var (
foregroundMode bool foregroundMode bool
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool
upCmd = &cobra.Command{ upCmd = &cobra.Command{
Use: "up", Use: "up",
@@ -65,6 +69,9 @@ func init() {
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+ `E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`, `or --extra-dns-labels ""`,
) )
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
} }
func upFunc(cmd *cobra.Command, args []string) error { func upFunc(cmd *cobra.Command, args []string) error {
@@ -212,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
r.GetFullStatus() r.GetFullStatus()
connectClient := internal.NewConnectClient(ctx, config, r) connectClient := internal.NewConnectClient(ctx, config, r)
SetupDebugHandler(ctx, config, r, connectClient, "")
return connectClient.Run(nil) return connectClient.Run(nil)
} }
@@ -259,7 +268,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
NatExternalIPs: natExternalIPs, NatExternalIPs: natExternalIPs,
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0, CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
CustomDNSAddress: customDNSAddressConverted, CustomDNSAddress: customDNSAddressConverted,
IsLinuxDesktopClient: isLinuxRunningDesktop(), IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName, Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList, ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels, DnsLabels: dnsLabels,
@@ -349,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
if loginResp.NeedsSSOLogin { if loginResp.NeedsSSOLogin {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) _, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {

View File

@@ -10,17 +10,18 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes) fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes) return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
} else { } else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes) fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -75,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,

View File

@@ -96,36 +96,36 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
_ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName) return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
} }
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -196,13 +196,13 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"",
) )
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)
@@ -242,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -97,7 +97,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -15,7 +15,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -41,6 +41,8 @@ const (
jumpManglePre = "jump-mangle-pre" jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre" jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post" jumpNatPost = "jump-nat-post"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set" matchSet = "--match-set"
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
@@ -55,18 +57,18 @@ type ruleInfo struct {
} }
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Sources []netip.Prefix Source firewall.Network
Destination netip.Prefix Destination firewall.Network
Proto firewall.Protocol Proto firewall.Protocol
SPort *firewall.Port SPort *firewall.Port
DPort *firewall.Port DPort *firewall.Port
Direction firewall.RuleDirection Direction firewall.RuleDirection
Action firewall.Action Action firewall.Action
SetName string
} }
type routeRules map[string][]string type routeRules map[string][]string
// the ipset library currently does not support comments, so we use the name only (string)
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}] type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
type router struct { type router struct {
@@ -115,45 +117,51 @@ func (r *router) init(stateManager *statemanager.Manager) error {
return fmt.Errorf("create containers: %w", err) return fmt.Errorf("create containers: %w", err)
} }
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
r.updateState() r.updateState()
return nil return nil
} }
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }
var setName string var source firewall.Network
if len(sources) > 1 { if len(sources) > 1 {
setName = firewall.GenerateSetName(sources) source.Set = firewall.NewPrefixSet(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil { } else if len(sources) > 0 {
return nil, fmt.Errorf("create or get ipset: %w", err) source.Prefix = sources[0]
}
} }
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Sources: sources, Source: source,
Destination: destination, Destination: destination,
Proto: proto, Proto: proto,
SPort: sPort, SPort: sPort,
DPort: dPort, DPort: dPort,
Action: action, Action: action,
SetName: setName,
} }
rule := genRouteFilteringRuleSpec(params) rule, err := r.genRouteRuleSpec(params, sources)
if err != nil {
return nil, fmt.Errorf("generate route rule spec: %w", err)
}
// Insert DROP rules at the beginning, append ACCEPT rules at the end // Insert DROP rules at the beginning, append ACCEPT rules at the end
var err error
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
// after the established rule // after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...) err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
@@ -176,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.ID() ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if setName != "" { if err := r.decrementSetCounter(rule); err != nil {
if _, err := r.ipsetCounter.Decrement(setName); err != nil { return fmt.Errorf("decrement ipset counter: %w", err)
return fmt.Errorf("failed to remove ipset: %w", err)
}
} }
} else { } else {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@@ -197,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return nil return nil
} }
func (r *router) findSetNameInRule(rule []string) string { func (r *router) decrementSetCounter(rule []string) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule []string) []string {
var sets []string
for i, arg := range rule { for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet { if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3] sets = append(sets, rule[i+3])
} }
} }
return "" return sets
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) error { func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
@@ -224,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil { if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err) return fmt.Errorf("destroy set %s: %w", setName, err)
} }
log.Debugf("Deleted unused ipset %s", setName)
return nil return nil
} }
@@ -263,6 +282,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
log.Errorf("%v", err) log.Errorf("%v", err)
} }
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
@@ -270,6 +290,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse nat rule: %w", err) return fmt.Errorf("remove inverse nat rule: %w", err)
} }
}
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err) return fmt.Errorf("remove legacy routing rule: %w", err)
@@ -306,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
log.Debugf("legacy forwarding rule %s not found", ruleKey) if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} }
return nil return nil
@@ -347,12 +370,16 @@ func (r *router) Reset() error {
if err := r.cleanUpDefaultForwardRules(); err != nil { if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
r.rules = make(map[string][]string)
if err := r.ipsetCounter.Flush(); err != nil { if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err) merr = multierror.Append(merr, err)
} }
if err := r.cleanupDataPlaneMark(); err != nil {
merr = multierror.Append(merr, err)
}
r.rules = make(map[string][]string)
r.updateState() r.updateState()
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
@@ -422,6 +449,57 @@ func (r *router) createContainers() error {
return nil return nil
} }
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
var merr *multierror.Error
preRule := []string{
"-i", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
} else {
r.rules[markManglePre] = preRule
}
postRule := []string{
"-o", r.wgIface.Name(),
"-m", "conntrack", "--ctstate", "NEW",
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
}
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
} else {
r.rules[markManglePost] = postRule
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) cleanupDataPlaneMark() error {
var merr *multierror.Error
if preRule, exists := r.rules[markManglePre]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
} else {
delete(r.rules, markManglePre)
}
}
if postRule, exists := r.rules[markManglePost]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
} else {
delete(r.rules, markManglePost)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) addPostroutingRules() error { func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade // First rule for outbound masquerade
rule1 := []string{ rule1 := []string{
@@ -463,7 +541,7 @@ func (r *router) insertEstablishedRule(chain string) error {
} }
func (r *router) addJumpRules() error { func (r *router) addJumpRules() error {
// Jump to NAT chain // Jump to nat chain
natRule := []string{"-j", chainRTNAT} natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat postrouting jump rule: %v", err) return fmt.Errorf("add nat postrouting jump rule: %v", err)
@@ -537,12 +615,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
rule = append(rule, rule = append(rule,
"-m", "conntrack", "-m", "conntrack",
"--ctstate", "NEW", "--ctstate", "NEW",
"-s", pair.Source.String(), )
"-d", pair.Destination.String(), sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
if err != nil {
return fmt.Errorf("apply network -s: %w", err)
}
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
if err != nil {
return fmt.Errorf("apply network -d: %w", err)
}
rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
rule = append(rule,
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue), "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
) )
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil { // Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
// TODO: rollback ipset counter
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
} }
@@ -560,6 +652,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err) return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement ipset counter: %w", err)
}
} else { } else {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
@@ -725,17 +821,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
var rule []string var rule []string
if params.SetName != "" { sourceExp, err := r.applyNetwork("-s", params.Source, sources)
rule = append(rule, "-m", "set", matchSet, params.SetName, "src") if err != nil {
} else if len(params.Sources) > 0 { return nil, fmt.Errorf("apply network -s: %w", err)
source := params.Sources[0]
rule = append(rule, "-s", source.String()) }
destExp, err := r.applyNetwork("-d", params.Destination, nil)
if err != nil {
return nil, fmt.Errorf("apply network -d: %w", err)
} }
rule = append(rule, "-d", params.Destination.String()) rule = append(rule, sourceExp...)
rule = append(rule, destExp...)
if params.Proto != firewall.ProtocolALL { if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto))) rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
@@ -745,7 +845,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
rule = append(rule, "-j", actionToStr(params.Action)) rule = append(rule, "-j", actionToStr(params.Action))
return rule return rule, nil
}
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
direction := "src"
if flag == "-d" {
direction = "dst"
}
if network.IsSet() {
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
}
if network.IsPrefix() {
return []string{flag, network.Prefix.String()}, nil
}
// nolint:nilnil
return nil, nil
}
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
var merr *multierror.Error
for _, prefix := range prefixes {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
}
}
if merr == nil {
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
}
return nberrors.FormatErrorOrNil(merr)
} }
func applyPort(flag string, port *firewall.Port) []string { func applyPort(flag string, port *firewall.Port) []string {

View File

@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 5. jump rule to PRE nat chain // 5. jump rule to PRE nat chain
// 6. static outbound masquerade rule // 6. static outbound masquerade rule
// 7. static return masquerade rule // 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map") // 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -58,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.100.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
Masquerade: true, Masquerade: true,
} }
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
@@ -345,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")
var source firewall.Network
if len(tt.sources) > 1 {
source.Set = firewall.NewPrefixSet(tt.sources)
} else if len(tt.sources) > 0 {
source.Prefix = tt.sources[0]
}
// Verify rule content // Verify rule content
params := routeFilteringRuleParams{ params := routeFilteringRuleParams{
Sources: tt.sources, Source: source,
Destination: tt.destination, Destination: firewall.Network{Prefix: tt.destination},
Proto: tt.proto, Proto: tt.proto,
SPort: tt.sPort, SPort: tt.sPort,
DPort: tt.dPort, DPort: tt.dPort,
Action: tt.action, Action: tt.action,
SetName: "",
} }
expectedRule := genRouteFilteringRuleSpec(params) expectedRule, err := r.genRouteRuleSpec(params, nil)
require.NoError(t, err, "Failed to generate expected rule spec")
if tt.expectSet { if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources) setName := firewall.NewPrefixSet(tt.sources).HashedName()
params.SetName = setName expectedRule, err = r.genRouteRuleSpec(params, nil)
expectedRule = genRouteFilteringRuleSpec(params) require.NoError(t, err, "Failed to generate expected rule spec with set")
// Check if the set was created // Check if the set was created
_, exists := r.ipsetCounter.Get(setName) _, exists := r.ipsetCounter.Get(setName)
@@ -376,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
} }
} }
func TestFindSetNameInRule(t *testing.T) {
r := &router{}
testCases := []struct {
name string
rule []string
expected []string
}{
{
name: "Basic rule with two sets",
rule: []string{
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
},
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
},
{
name: "No sets",
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
expected: []string{},
},
{
name: "Multiple sets with different positions",
rule: []string{
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
},
expected: []string{"set1", "set-abc123"},
},
{
name: "Boundary case - sequence appears at end",
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
expected: []string{"final-set"},
},
{
name: "Incomplete pattern - missing set name",
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
expected: []string{},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := r.findSets(tc.rule)
if len(result) != len(tc.expected) {
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
return
}
for i, set := range result {
if set != tc.expected[i] {
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
}
}
})
}
}

View File

@@ -1,13 +1,10 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"sort" "sort"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -43,6 +40,18 @@ const (
// Action is the action to be taken on a rule // Action is the action to be taken on a rule
type Action int type Action int
// String returns the string representation of the action
func (a Action) String() string {
switch a {
case ActionAccept:
return "accept"
case ActionDrop:
return "drop"
default:
return "unknown"
}
}
const ( const (
// ActionAccept is the action to accept a packet // ActionAccept is the action to accept a packet
ActionAccept Action = iota ActionAccept Action = iota
@@ -50,6 +59,33 @@ const (
ActionDrop ActionDrop
) )
// Network is a rule destination, either a set or a prefix
type Network struct {
Set Set
Prefix netip.Prefix
}
// String returns the string representation of the destination
func (d Network) String() string {
if d.Prefix.IsValid() {
return d.Prefix.String()
}
if d.IsSet() {
return d.Set.HashedName()
}
return "<invalid network>"
}
// IsSet returns true if the destination is a set
func (d Network) IsSet() bool {
return d.Set != Set{}
}
// IsPrefix returns true if the destination is a valid prefix
func (d Network) IsPrefix() bool {
return d.Prefix.IsValid()
}
// Manager is the high level abstraction of a firewall manager // Manager is the high level abstraction of a firewall manager
// //
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
@@ -65,13 +101,13 @@ type Manager interface {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
action Action, action Action,
ipsetName string, ipsetName string,
comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -80,7 +116,14 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination Network,
proto Protocol,
sPort, dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule // DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error DeleteRouteRule(rule Rule) error
@@ -111,6 +154,9 @@ type Manager interface {
// DeleteDNATRule deletes a DNAT rule // DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error DeleteDNATRule(Rule) error
// UpdateSet updates the set with the given prefixes
UpdateSet(hash Set, prefixes []netip.Prefix) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {
@@ -145,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
return nil return nil
} }
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix // MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix { func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 { if len(prefixes) == 0 {

View File

@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.GenerateSetName(prefixes1) result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.GenerateSetName(prefixes2) result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("10.0.0.0/8"),
} }
result := manager.GenerateSetName(prefixes) result := manager.NewPrefixSet(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result) matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
if err != nil { if err != nil {
t.Fatalf("Error matching regex: %v", err) t.Fatalf("Error matching regex: %v", err)
} }
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
}) })
t.Run("Empty input produces consistent result", func(t *testing.T) { t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{}) result1 := manager.NewPrefixSet([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{}) result2 := manager.NewPrefixSet([]netip.Prefix{})
if result1 != result2 { if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2) t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
netip.MustParsePrefix("192.168.1.0/24"), netip.MustParsePrefix("192.168.1.0/24"),
} }
result1 := manager.GenerateSetName(prefixes1) result1 := manager.NewPrefixSet(prefixes1)
result2 := manager.GenerateSetName(prefixes2) result2 := manager.NewPrefixSet(prefixes2)
if result1 != result2 { if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2) t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)

View File

@@ -1,15 +1,13 @@
package manager package manager
import ( import (
"net/netip"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type RouterPair struct { type RouterPair struct {
ID route.ID ID route.ID
Source netip.Prefix Source Network
Destination netip.Prefix Destination Network
Masquerade bool Masquerade bool
Inverse bool Inverse bool
} }

View File

@@ -0,0 +1,74 @@
package manager
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"net/netip"
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/domain"
)
type Set struct {
hash [4]byte
comment string
}
// String returns the string representation of the set: hashed name and comment
func (h Set) String() string {
if h.comment == "" {
return h.HashedName()
}
return h.HashedName() + ": " + h.comment
}
// HashedName returns the string representation of the hash
func (h Set) HashedName() string {
return fmt.Sprintf(
"nb-%s",
hex.EncodeToString(h.hash[:]),
)
}
// Comment returns the comment of the set
func (h Set) Comment() string {
return h.comment
}
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
func NewPrefixSet(prefixes []netip.Prefix) Set {
// sort for consistent naming
SortPrefixes(prefixes)
hash := sha256.New()
for _, src := range prefixes {
bytes, err := src.MarshalBinary()
if err != nil {
log.Warnf("failed to marshal prefix %s: %v", src, err)
}
hash.Write(bytes)
}
var set Set
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}
// NewDomainSet generates a unique name for an ipset based on the given domains.
func NewDomainSet(domains domain.List) Set {
slices.Sort(domains)
hash := sha256.New()
for _, d := range domains {
hash.Write([]byte(d.PunycodeString()))
}
set := Set{
comment: domains.SafeString(),
}
copy(set.hash[:], hash.Sum(nil)[:4])
return set
}

View File

@@ -27,7 +27,8 @@ const (
// filter chains contains the rules that jump to the rules chains // filter chains contains the rules that jump to the rules chains
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting" chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
@@ -84,13 +85,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var ipset *nftables.Set var ipset *nftables.Set
if ipsetName != "" { if ipsetName != "" {
@@ -102,7 +103,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -256,7 +257,6 @@ func (m *AclManager) addIOFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipset *nftables.Set, ipset *nftables.Set,
comment string,
) (*Rule, error) { ) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
@@ -338,7 +338,7 @@ func (m *AclManager) addIOFiltering(
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(ruleId)
chain := m.chainInputRules chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
@@ -463,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the // go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP. // netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error { func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{ // Chain is created by route manager
Name: chainNamePrerouting, // TODO: move creation to a common place
m.chainPrerouting = &nftables.Chain{
Name: chainNameManglePrerouting,
Table: m.workTable, Table: m.workTable,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
}) }
m.addFwmarkToForward(chainFwFilter) m.addFwmarkToForward(chainFwFilter)

View File

@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -129,25 +129,25 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort, dPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if !destination.Addr().Is4() { if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
} }
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -241,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy) return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Close closes the firewall manager
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -358,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
return m.router.DeleteDNATRule(rule) return m.router.DeleteDNATRule(rule)
} }
// UpdateSet updates the set with the given prefixes
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.UpdateSet(set, prefixes)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "") rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -283,12 +283,13 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule") _, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"), fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
&fw.Port{Values: []uint16{443}}, &fw.Port{Values: []uint16{443}},
@@ -297,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, err, "failed to add route filtering rule") require.NoError(t, err, "failed to add route filtering rule")
pair := fw.RouterPair{ pair := fw.RouterPair{
Source: netip.MustParsePrefix("192.168.1.0/24"), Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
Destination: netip.MustParsePrefix("10.0.0.0/24"), Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
Masquerade: true, Masquerade: true,
} }
err = manager.AddNatRule(pair) err = manager.AddNatRule(pair)

View File

@@ -10,7 +10,6 @@ import (
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
@@ -20,7 +19,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -44,9 +43,14 @@ const (
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
var ( var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found") errFilterTableNotFound = fmt.Errorf("'filter' table not found")
) )
type setInput struct {
set firewall.Set
prefixes []netip.Prefix
}
type router struct { type router struct {
conn *nftables.Conn conn *nftables.Conn
workTable *nftables.Table workTable *nftables.Table
@@ -54,7 +58,7 @@ type router struct {
chains map[string]*nftables.Chain chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules // rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
@@ -100,6 +104,10 @@ func (r *router) init(workTable *nftables.Table) error {
return fmt.Errorf("create containers: %w", err) return fmt.Errorf("create containers: %w", err)
} }
if err := r.setupDataPlaneMark(); err != nil {
log.Errorf("failed to set up data plane mark: %v", err)
}
return nil return nil
} }
@@ -159,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err) return nil, fmt.Errorf("unable to list tables: %v", err)
} }
for _, table := range tables { for _, table := range tables {
@@ -196,15 +204,21 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
// Chain is created by acl manager r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
// TODO: move creation to a common place Name: chainNameManglePostrouting,
r.chains[chainNamePrerouting] = &nftables.Chain{ Table: r.workTable,
Name: chainNamePrerouting, Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
Name: chainNameManglePrerouting,
Table: r.workTable, Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
} })
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil { if err := r.addPostroutingRules(); err != nil {
@@ -220,7 +234,83 @@ func (r *router) createContainers() error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err) return fmt.Errorf("initialize tables: %v", err)
}
return nil
}
// setupDataPlaneMark configures the fwmark for the data plane
func (r *router) setupDataPlaneMark() error {
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
return errors.New("no mangle chains found")
}
ctNew := getCtNewExprs()
preExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
preExprs = append(preExprs, ctNew...)
preExprs = append(preExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
preNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePrerouting],
Exprs: preExprs,
}
r.conn.AddRule(preNftRule)
postExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
postExprs = append(postExprs, ctNew...)
postExprs = append(postExprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
)
postNftRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameManglePostrouting],
Exprs: postExprs,
}
r.conn.AddRule(postNftRule)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
} }
return nil return nil
@@ -228,15 +318,16 @@ func (r *router) createContainers() error {
// AddRouteFiltering appends a nftables rule to the routing chain // AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination firewall.Network,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }
@@ -244,23 +335,29 @@ func (r *router) AddRouteFiltering(
chain := r.chains[chainNameRoutingFw] chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any var exprs []expr.Any
var source firewall.Network
switch { switch {
case len(sources) == 1 && sources[0].Bits() == 0: case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching // If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1: case len(sources) == 1:
// If there's only one source, we can use it directly // If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...) source.Prefix = sources[0]
default: default:
// If there are multiple sources, create or get an ipset // If there are multiple sources, use a set
var err error source.Set = firewall.NewPrefixSet(sources)
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
} }
// Handle destination sourceExp, err := r.applyNetwork(source, sources, true)
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...) if err != nil {
return nil, fmt.Errorf("apply source: %w", err)
}
exprs = append(exprs, sourceExp...)
destExp, err := r.applyNetwork(destination, nil, false)
if err != nil {
return nil, fmt.Errorf("apply destination: %w", err)
}
exprs = append(exprs, destExp...)
// Handle protocol // Handle protocol
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
@@ -304,39 +401,27 @@ func (r *router) AddRouteFiltering(
rule = r.conn.AddRule(rule) rule = r.conn.AddRule(rule)
} }
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err) return nil, fmt.Errorf(flushError, err)
} }
r.rules[string(ruleKey)] = rule r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action) log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil return ruleKey, nil
} }
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) { func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources) ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
ref, err := r.ipsetCounter.Increment(setName, sources) set: set,
prefixes: prefixes,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err) return nil, fmt.Errorf("create or get ipset: %w", err)
} }
exprs = append(exprs, return getIpSetExprs(ref, isSource)
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
@@ -355,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf("route rule %s has no handle", ruleKey) return fmt.Errorf("route rule %s has no handle", ruleKey)
} }
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil { if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err) return fmt.Errorf("delete: %w", err)
} }
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
if err := r.decrementSetCounter(nftRule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
return nil return nil
} }
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) { func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them // overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources) prefixes := firewall.MergeIPRanges(input.prefixes)
set := &nftables.Set{ nfset := &nftables.Set{
Name: setName, Name: setName,
Comment: input.set.Comment(),
Table: r.workTable, Table: r.workTable,
// required for prefixes // required for prefixes
Interval: true, Interval: true,
KeyType: nftables.TypeIPAddr, KeyType: nftables.TypeIPAddr,
} }
elements := convertPrefixesToSet(prefixes)
if err := r.conn.AddSet(nfset, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return nfset, nil
}
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
var elements []nftables.SetElement var elements []nftables.SetElement
for _, prefix := range sources { for _, prefix := range prefixes {
// TODO: Implement IPv6 support // TODO: Implement IPv6 support
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix) log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue continue
} }
@@ -406,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true}, nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
) )
} }
return elements
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
} }
// calculateLastIP determines the last IP in a given prefix. // calculateLastIP determines the last IP in a given prefix.
@@ -441,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
return b return b
} }
func (r *router) deleteIpSet(setName string, set *nftables.Set) error { func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
r.conn.DelSet(set) r.conn.DelSet(nfset)
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err) return fmt.Errorf(flushError, err)
} }
@@ -451,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
return nil return nil
} }
func (r *router) findSetNameInRule(rule *nftables.Rule) string { func (r *router) decrementSetCounter(rule *nftables.Rule) error {
sets := r.findSets(rule)
var merr *multierror.Error
for _, setName := range sets {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) findSets(rule *nftables.Rule) []string {
var sets []string
for _, e := range rule.Exprs { for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok { if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName sets = append(sets, lookup.SetName)
} }
} }
return "" return sets
} }
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error { func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
@@ -499,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err) // TODO: rollback ipset counter
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
} }
return nil return nil
@@ -507,8 +608,15 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// addNatRule inserts a nftables rule to the conn client flush queue // addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error { func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
op := expr.CmpOpEq op := expr.CmpOpEq
if pair.Inverse { if pair.Inverse {
@@ -516,26 +624,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
exprs := []expr.Any{ exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyIIFNAME, Key: expr.MetaKeyIIFNAME,
Register: 1, Register: 1,
@@ -546,6 +634,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
Data: ifname(r.wgIface.Name()), Data: ifname(r.wgIface.Name()),
}, },
} }
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
exprs = append(exprs, getCtNewExprs()...)
exprs = append(exprs, sourceExp...) exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...) exprs = append(exprs, destExp...)
@@ -575,9 +666,11 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
} }
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ // Ensure nat rules come first, so the mark can be overwritten.
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNamePrerouting], Chain: r.chains[chainNameManglePrerouting],
Exprs: exprs, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
@@ -658,8 +751,15 @@ func (r *router) addPostroutingRules() error {
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source) sourceExp, err := r.applyNetwork(pair.Source, nil, true)
destExp := generateCIDRMatcherExpressions(false, pair.Destination) if err != nil {
return fmt.Errorf("apply source: %w", err)
}
destExp, err := r.applyNetwork(pair.Destination, nil, false)
if err != nil {
return fmt.Errorf("apply destination: %w", err)
}
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Counter{}, &expr.Counter{},
@@ -668,7 +768,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
}, },
} }
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
@@ -681,7 +782,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{ r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable, Table: r.workTable,
Chain: r.chains[chainNameRoutingFw], Chain: r.chains[chainNameRoutingFw],
Exprs: expression, Exprs: exprs,
UserData: []byte(ruleKey), UserData: []byte(ruleKey),
}) })
return nil return nil
@@ -696,11 +797,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination) log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey) if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} }
return nil return nil
@@ -911,6 +1014,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
if pair.Masquerade {
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err) return fmt.Errorf("remove prerouting rule: %w", err)
} }
@@ -918,16 +1022,17 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil { if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err) return fmt.Errorf("remove inverse prerouting rule: %w", err)
} }
}
if err := r.removeLegacyRouteRule(pair); err != nil { if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err) return fmt.Errorf("remove legacy routing rule: %w", err)
} }
if err := r.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err) // TODO: rollback set counter
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
} }
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil return nil
} }
@@ -935,16 +1040,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair) ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule) if err := r.conn.DelRule(rule); err != nil {
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination) log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
if err := r.decrementSetCounter(rule); err != nil {
return fmt.Errorf("decrement set counter: %w", err)
}
} else { } else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey) log.Debugf("prerouting rule %s not found", ruleKey)
} }
return nil return nil
@@ -956,7 +1064,7 @@ func (r *router) refreshRulesMap() error {
for _, chain := range r.chains { for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain) rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil { if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err) return fmt.Errorf(" unable to list rules: %v", err)
} }
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 { if len(rule.UserData) > 0 {
@@ -1230,13 +1338,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
var offset uint32 if err != nil {
if source { return fmt.Errorf("get set %s: %w", set.HashedName(), err)
offset = 12 // src offset }
} else {
offset = 16 // dst offset elements := convertPrefixesToSet(prefixes)
if err := r.conn.SetAddElements(nfset, elements); err != nil {
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
return nil
}
// applyNetwork generates nftables expressions for networks (CIDR) or sets
func (r *router) applyNetwork(
network firewall.Network,
setPrefixes []netip.Prefix,
isSource bool,
) ([]expr.Any, error) {
if network.IsSet() {
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
if err != nil {
return nil, fmt.Errorf("source: %w", err)
}
return exprs, nil
}
if network.IsPrefix() {
return applyPrefix(network.Prefix, isSource), nil
}
return nil, nil
}
// applyPrefix generates nftables expressions for a CIDR prefix
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
} }
ones := prefix.Bits() ones := prefix.Bits()
@@ -1323,3 +1472,48 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
return exprs return exprs
} }
func getCtNewExprs() []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
}
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
// dst offset
offset := uint32(16)
if isSource {
// src offset
offset = 12
}
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
}, nil
}

View File

@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
} }
// Build CIDR matching expressions // Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
// Combine all expressions in the correct order // Combine all expressions in the correct order
// nolint:gocritic // nolint:gocritic
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range rtr.chains { for _, chain := range rtr.chains {
if chain.Name == chainNamePrerouting { if chain.Name == chainNameManglePrerouting {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules { for _, rule := range rules {
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was added // Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair) natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules") require.NoError(t, err, "should list rules")
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
// Verify the rule was removed // Verify the rule was removed
found = false found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting]) rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
require.NoError(t, err, "should list rules after removal") require.NoError(t, err, "should list rules after removal")
for _, rule := range rules { for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources) setName := firewall.NewPrefixSet(tt.sources).HashedName()
set, err := r.createIpSet(setName, tt.sources) set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
if err != nil { if err != nil {
t.Logf("Failed to create IP set: %v", err) t.Logf("Failed to create IP set: %v", err)
printNftSets() printNftSets()

View File

@@ -15,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -24,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -40,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: netip.MustParsePrefix("100.100.100.1/32"), Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
Destination: netip.MustParsePrefix("100.100.200.0/24"), Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -4,6 +4,7 @@ package uspfilter
import ( import (
"context" "context"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -11,13 +12,13 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
@@ -31,8 +32,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.tcpTracker.Close() m.tcpTracker.Close()
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time" "time"
@@ -20,13 +21,13 @@ const (
firewallRuleName = "Netbird" firewallRuleName = "Netbird"
) )
// Close closes the firewall manager // Close cleans up the firewall manager by removing all rules and closing trackers
func (m *Manager) Close(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
@@ -40,8 +41,8 @@ func (m *Manager) Close(*statemanager.Manager) error {
m.tcpTracker.Close() m.tcpTracker.Close()
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@@ -1,20 +1,27 @@
// common.go
package conntrack package conntrack
import ( import (
"net" "fmt"
"sync" "net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
SourceIP net.IP FlowId uuid.UUID
DestIP net.IP Direction nftypes.Direction
SourcePort uint16 SourceIP netip.Addr
DestPort uint16 DestIP netip.Addr
lastSeen atomic.Int64 // Unix nano for atomic access lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout return time.Since(lastSeen) > timeout
} }
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection // ConnKey uniquely identifies a connection
type ConnKey struct { type ConnKey struct {
SrcIP IPAddr SrcIP netip.Addr
DstIP IPAddr DstIP netip.Addr
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
} }
// makeConnKey creates a connection key func (c ConnKey) String() string {
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
} }

View File

@@ -1,94 +1,66 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
} }
} }
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
} }
} }
}) })

View File

@@ -2,13 +2,16 @@ package conntrack
import ( import (
"context" "context"
"net" "fmt"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -20,18 +23,20 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct { type ICMPConnKey struct {
// Supports both IPv4 and IPv6 SrcIP netip.Addr
SrcIP [16]byte DstIP netip.Addr
DstIP [16]byte ID uint16
Sequence uint16 // ICMP sequence number }
ID uint16 // ICMP identifier
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct { type ICMPConnTrack struct {
BaseConnTrack BaseConnTrack
Sequence uint16 ICMPType uint8
ID uint16 ICMPCode uint8
} }
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
@@ -42,11 +47,11 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
ipPool *PreallocatedIPs flowLogger nftypes.FlowLogger
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
@@ -59,67 +64,108 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := ICMPConnKey{
key := makeICMPKey(srcIP, dstIP, id, seq) SrcIP: srcIP,
DstIP: dstIP,
t.mutex.Lock()
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id, ID: id,
Sequence: seq,
} }
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
}
t.mutex.Unlock()
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false return false
} }
if conn.timeoutExceeded(t.timeout) { key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false return false
} }
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && conn.UpdateLastSeen()
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.UpdateCounters(nftypes.Ingress, size)
conn.ID == id &&
conn.Sequence == seq return true
} }
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) { func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
@@ -134,17 +180,18 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
} }
} }
} }
func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) cleanup() {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %v (timeout)", key) t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
@@ -154,20 +201,46 @@ func (t *ICMPTracker) Close() {
t.tickerCancel() t.tickerCancel()
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// makeICMPKey creates an ICMP connection key func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { t.flowLogger.StoreEvent(nftypes.EventFields{
return ICMPConnKey{ FlowID: conn.FlowId,
SrcIP: MakeIPAddr(srcIP), Type: typ,
DstIP: MakeIPAddr(dstIP), RuleID: ruleID,
ID: id, Direction: conn.Direction,
Sequence: seq, Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
} }
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
} }

View File

@@ -1,39 +1,39 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
) )
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
} }
}) })
} }

View File

@@ -4,12 +4,15 @@ package conntrack
import ( import (
"context" "context"
"net" "net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -20,11 +23,11 @@ const (
) )
const ( const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01 TCPFin uint8 = 0x01
TCPSyn uint8 = 0x02
TCPRst uint8 = 0x04 TCPRst uint8 = 0x04
TCPPush uint8 = 0x08 TCPPush uint8 = 0x08
TCPAck uint8 = 0x10
TCPUrg uint8 = 0x20 TCPUrg uint8 = 0x20
) )
@@ -38,7 +41,36 @@ const (
) )
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
type TCPState int type TCPState int32
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const ( const (
TCPStateNew TCPState = iota TCPStateNew TCPState = iota
@@ -54,30 +86,38 @@ const (
TCPStateClosed TCPStateClosed
) )
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
State TCPState SourcePort uint16
established atomic.Bool DestPort uint16
sync.RWMutex state atomic.Int32
tombstone atomic.Bool
} }
// IsEstablished safely checks if connection is established // GetState safely retrieves the current state
func (t *TCPConnTrack) IsEstablished() bool { func (t *TCPConnTrack) GetState() TCPState {
return t.established.Load() return TCPState(t.state.Load())
} }
// SetEstablished safely sets the established state // SetState safely updates the current state
func (t *TCPConnTrack) SetEstablished(state bool) { func (t *TCPConnTrack) SetState(state TCPState) {
t.established.Store(state) t.state.Store(int32(state))
}
// CompareAndSwapState atomically changes the state from old to new if current == old
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
return t.state.CompareAndSwap(int32(old), int32(newState))
}
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
} }
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
@@ -88,11 +128,18 @@ type TCPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
ipPool *PreallocatedIPs waitTimeout time.Duration
flowLogger nftypes.FlowLogger
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
waitTimeout := TimeWaitTimeout
if timeout == 0 {
timeout = DefaultTCPTimeout
} else {
waitTimeout = timeout / 45
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@@ -102,179 +149,211 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
timeout: timeout, timeout: timeout,
ipPool: NewPreallocatedIPs(), waitTimeout: waitTimeout,
flowLogger: flowLogger,
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound processes an outbound TCP packet and updates connection state func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { key := ConnKey{
// Create key before lock SrcIP: srcIP,
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) DstIP: dstIP,
SrcPort: srcPort,
t.mutex.Lock() DstPort: dstPort,
conn, exists := t.connections[key]
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
} }
conn.UpdateLastSeen()
conn.established.Store(false)
t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
}
t.mutex.Unlock()
// Lock individual connection for state update
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
if !isValidFlagCombination(flags) {
return false
}
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
return false t.updateState(key, conn, flags, direction, size)
return key, true
} }
// Handle RST packets return key, false
if flags&TCPRst != 0 {
conn.Lock()
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true
}
conn.Unlock()
return false
} }
conn.Lock() // TrackOutbound records an outbound TCP connection
t.updateState(conn, flags, false) func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
conn.UpdateLastSeen() if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
isEstablished := conn.IsEstablished() // if (inverted direction) conn is not tracked, track this direction
isValidState := t.isValidStateForFlags(conn.State, flags) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
conn.Unlock() }
return isEstablished || isValidState
} }
// updateState updates the TCP connection state based on flags // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
// Handle RST flag specially - it always causes transition to closed t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
if flags&TCPRst != 0 { }
conn.State = TCPStateClosed
conn.SetEstablished(false)
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d", // track is the common implementation for tracking both inbound and outbound connections
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists || flags&TCPSyn == 0 {
return return
} }
switch conn.State { conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.tombstone.Store(false)
conn.state.Store(int32(TCPStateNew))
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.IsTombstone() {
return false
}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
// allow all flags for established for now
if currentState == TCPStateEstablished {
return true
}
return false
}
t.updateState(key, conn, flags, nftypes.Ingress, size)
return true
}
// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
currentState := conn.GetState()
if flags&TCPRst != 0 {
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
conn.SetTombstone()
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
return
}
var newState TCPState
switch currentState {
case TCPStateNew: case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent if conn.Direction == nftypes.Egress {
newState = TCPStateSynSent
} else {
newState = TCPStateSynReceived
}
} }
case TCPStateSynSent: case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound { if packetDir != conn.Direction {
conn.State = TCPStateSynReceived newState = TCPStateEstablished
} else { } else {
// Simultaneous open // Simultaneous open
conn.State = TCPStateEstablished newState = TCPStateSynReceived
conn.SetEstablished(true)
} }
} }
case TCPStateSynReceived: case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 { if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished if packetDir == conn.Direction {
conn.SetEstablished(true) newState = TCPStateEstablished
}
} }
case TCPStateEstablished: case TCPStateEstablished:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
if isOutbound { if packetDir == conn.Direction {
conn.State = TCPStateFinWait1 newState = TCPStateFinWait1
} else { } else {
conn.State = TCPStateCloseWait newState = TCPStateCloseWait
} }
conn.SetEstablished(false)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
if packetDir != conn.Direction {
switch { switch {
case flags&TCPFin != 0 && flags&TCPAck != 0: case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN newState = TCPStateClosing
conn.State = TCPStateClosing
case flags&TCPFin != 0: case flags&TCPFin != 0:
conn.State = TCPStateFinWait2 newState = TCPStateClosing
case flags&TCPAck != 0: case flags&TCPAck != 0:
conn.State = TCPStateFinWait2 newState = TCPStateFinWait2
}
} }
case TCPStateFinWait2: case TCPStateFinWait2:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait newState = TCPStateTimeWait
} }
case TCPStateClosing: case TCPStateClosing:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait newState = TCPStateTimeWait
// Keep established = false from previous state
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateLastAck newState = TCPStateLastAck
} }
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed newState = TCPStateClosed
}
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
case TCPStateTimeWait: if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
// Stay in TIME-WAIT for 2MSL before transitioning to closed t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
// This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d", switch newState {
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) case TCPStateTimeWait:
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
} }
} }
@@ -283,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
if !isValidFlagCombination(flags) { if !isValidFlagCombination(flags) {
return false return false
} }
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
}
return true
}
switch state { switch state {
case TCPStateNew: case TCPStateNew:
return flags&TCPSyn != 0 && flags&TCPAck == 0 return flags&TCPSyn != 0 && flags&TCPAck == 0
case TCPStateSynSent: case TCPStateSynSent:
// TODO: support simultaneous open
return flags&TCPSyn != 0 && flags&TCPAck != 0 return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived: case TCPStateSynReceived:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateEstablished: case TCPStateEstablished:
if flags&TCPRst != 0 {
return true
}
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateFinWait1: case TCPStateFinWait1:
return flags&TCPFin != 0 || flags&TCPAck != 0 return flags&TCPFin != 0 || flags&TCPAck != 0
@@ -311,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
case TCPStateLastAck: case TCPStateLastAck:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateClosed: case TCPStateClosed:
// Accept retransmitted ACKs in closed state // Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0 return flags&TCPAck != 0
} }
return false return false
@@ -337,24 +418,33 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration var timeout time.Duration
switch { currentState := conn.GetState()
case conn.State == TCPStateTimeWait: switch currentState {
timeout = TimeWaitTimeout case TCPStateTimeWait:
case conn.IsEstablished(): timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout timeout = t.timeout
default: default:
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
lastSeen := conn.GetLastSeen() if conn.timeoutExceeded(timeout) {
if time.Since(lastSeen) > timeout {
// Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
// event already handled by state change
if currentState != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
} }
} }
} }
@@ -365,10 +455,6 @@ func (t *TCPTracker) Close() {
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
@@ -386,3 +472,21 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -0,0 +1,83 @@
package conntrack
import (
"net/netip"
"testing"
"time"
)
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
defer tracker.Close()
// Pre-populate with expired connections
srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
}

View File

@@ -1,19 +1,20 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -58,7 +59,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc) require.Equal(t, !tt.wantDrop, isValid, tt.desc)
}) })
} }
@@ -76,17 +77,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper() t.Helper()
// Send initial SYN // Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK // Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK // Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer // Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake") require.True(t, valid, "Data should be allowed after handshake")
}, },
}, },
@@ -99,18 +100,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN // Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN // Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed") require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side // Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed") require.True(t, valid, "FIN should be allowed")
// Send final ACK // Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
}, },
{ {
@@ -122,11 +123,8 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
}, },
}, },
{ {
@@ -138,13 +136,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK // Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed") require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK // Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed") require.True(t, valid, "Final ACKs should be allowed")
}, },
}, },
@@ -154,7 +152,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger) tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tt.test(t) tt.test(t)
}) })
} }
@@ -162,11 +160,11 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -181,12 +179,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established", name: "RST in established",
setupState: func() { setupState: func() {
// Establish connection first // Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: true, wantValid: true,
desc: "Should accept RST for established connection", desc: "Should accept RST for established connection",
@@ -195,7 +193,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection", name: "RST without connection",
setupState: func() {}, setupState: func() {},
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: false, wantValid: false,
desc: "Should reject RST without connection", desc: "Should reject RST without connection",
@@ -208,101 +206,455 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST() tt.sendRST()
// Verify connection state is as expected // Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
require.Equal(t, TCPStateClosed, conn.State) require.Equal(t, TCPStateClosed, conn.GetState())
require.False(t, conn.IsEstablished())
} }
}) })
} }
} }
func TestTCPRetransmissions(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test SYN retransmission
t.Run("SYN Retransmission", func(t *testing.T) {
// Initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Retransmit SYN (should not affect the state machine)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Verify we're still in SYN-SENT state
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateSynSent, conn.GetState())
// Complete the handshake
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Verify we're in ESTABLISHED state
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test ACK retransmission in established state
t.Run("ACK Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Retransmit ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
})
// Test FIN retransmission
t.Run("FIN Retransmission", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Retransmit FIN (should not change state)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
})
}
func TestTCPDataTransfer(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Data Transfer", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Receive ACK for data
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
require.True(t, valid)
// Receive data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
require.True(t, valid)
// Send ACK for received data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
// State should remain ESTABLISHED
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
})
}
func TestTCPHalfClosedConnections(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test half-closed connection: local end closes, remote end continues sending data
t.Run("Local Close, Remote Data", func(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Remote end can still send data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
require.True(t, valid)
// We can still ACK their data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Receive FIN from remote end
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// State should remain TIME-WAIT (waiting for possible retransmissions)
require.Equal(t, TCPStateTimeWait, conn.GetState())
})
// Test half-closed connection: remote end closes, local end continues sending data
t.Run("Remote Close, Local Data", func(t *testing.T) {
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
// Establish connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Receive FIN from remote
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateCloseWait, conn.GetState())
// We can still send data
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
// Remote can still ACK our data
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
// Send our FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Receive final ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
})
}
func TestTCPAbnormalSequences(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
// Test handling of unsolicited RST in various states
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
// Send SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive unsolicited RST (without proper ACK)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
// Receive RST with proper ACK
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.Equal(t, TCPStateClosed, conn.GetState())
require.True(t, conn.IsTombstone())
})
}
func TestTCPTimeoutHandling(t *testing.T) {
// Create tracker with a very short timeout for testing
shortTimeout := 100 * time.Millisecond
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
t.Run("Connection Timeout", func(t *testing.T) {
// Establish a connection
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Get connection object
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, TCPStateEstablished, conn.GetState())
// Wait for the connection to timeout
time.Sleep(2 * shortTimeout)
// Force cleanup
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after timeout")
})
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key]
require.NotNil(t, conn)
// Complete the connection close to enter TIME_WAIT
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState())
// TIME_WAIT should have its own timeout value (usually 2*MSL)
// For the test, we're using a short timeout
time.Sleep(2 * shortTimeout)
tracker.cleanup()
// Connection should be removed
_, exists := tracker.connections[key]
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
})
}
func TestSynFlood(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
basePort := uint16(10000)
dstPort := uint16(80)
// Create a large number of SYN packets to simulate a SYN flood
for i := uint16(0); i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
}
// Check that we're tracking all connections
require.Equal(t, 1000, len(tracker.connections))
// Now simulate SYN timeout
var oldConns int
tracker.mutex.Lock()
for _, conn := range tracker.connections {
if conn.GetState() == TCPStateSynSent {
// Make the connection appear old
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
oldConns++
}
}
tracker.mutex.Unlock()
require.Equal(t, 1000, oldConns)
// Run cleanup
tracker.cleanup()
// Check that stale connections were cleaned up
require.Equal(t, 0, len(tracker.connections))
}
func TestTCPInboundInitiatedConnection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
clientIP := netip.MustParseAddr("100.64.0.1")
serverIP := netip.MustParseAddr("100.64.0.2")
clientPort := uint16(12345)
serverPort := uint16(80)
// 1. Client sends SYN (we receive it as inbound)
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
key := ConnKey{
SrcIP: clientIP,
DstIP: serverIP,
SrcPort: clientPort,
DstPort: serverPort,
}
tracker.mutex.RLock()
conn := tracker.connections[key]
tracker.mutex.RUnlock()
require.NotNil(t, conn)
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
// 2. Server sends SYN-ACK response
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
// 3. Client sends ACK to complete handshake
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
// 4. Test data transfer
// Client sends data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
// Server sends ACK for data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
// Server sends data
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
// Client sends ACK for data
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
// Verify state and counters
require.Equal(t, TCPStateEstablished, conn.GetState())
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
}
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
}
func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
}
})
b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
// Pre-populate some connections
for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
}
})
b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
} else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
}
i++
}
})
})
}
// Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
defer tracker.Close()
// Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2")
for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
}
// Wait for connections to expire
time.Sleep(200 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
tracker.cleanup()
}
})
} }

View File

@@ -2,11 +2,14 @@ package conntrack
import ( import (
"context" "context"
"net" "net/netip"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -19,6 +22,8 @@ const (
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
type UDPConnTrack struct { type UDPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
} }
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
@@ -29,11 +34,11 @@ type UDPTracker struct {
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
ipPool *PreallocatedIPs flowLogger nftypes.FlowLogger
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
@@ -46,7 +51,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel, tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine(ctx) go tracker.cleanupRoutine(ctx)
@@ -54,55 +59,88 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.mutex.Lock() t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
} }
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
}
t.mutex.Unlock()
conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false return false
} }
if conn.timeoutExceeded(t.timeout) { conn.UpdateLastSeen()
return false conn.UpdateCounters(nftypes.Ingress, size)
}
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && return true
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
@@ -125,11 +163,11 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %v (timeout)", conn) t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
@@ -139,29 +177,44 @@ func (t *UDPTracker) Close() {
t.tickerCancel() t.tickerCancel()
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// GetConnection safely retrieves a connection state // GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock() t.mutex.RLock()
defer t.mutex.RUnlock() defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
conn, exists := t.connections[key] SrcIP: srcIP,
if !exists { DstIP: dstIP,
return nil, false SrcPort: srcPort,
DstPort: dstPort,
} }
conn, exists := t.connections[key]
return conn, true return conn, exists
} }
// Timeout returns the configured timeout duration for the tracker // Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration { func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -2,7 +2,7 @@ package conntrack
import ( import (
"context" "context"
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -30,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger) tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
@@ -41,43 +41,48 @@ func TestNewUDPTracker(t *testing.T) {
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked // Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key] conn, exists := tracker.connections[key]
require.True(t, exists) require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP)) assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Equal(dstIP)) assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort) assert.Equal(t, dstPort, conn.DestPort)
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger) tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
// Track outbound connection // Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct { tests := []struct {
name string name string
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
sleep time.Duration sleep time.Duration
@@ -94,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
}, },
{ {
name: "invalid source IP", name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP, dstIP: srcIP,
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
@@ -104,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{ {
name: "invalid destination IP", name: "invalid destination IP",
srcIP: dstIP, srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"), dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
sleep: 0, sleep: 0,
@@ -144,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 { if tt.sleep > 0 {
time.Sleep(tt.sleep) time.Sleep(tt.sleep)
} }
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
@@ -164,8 +169,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
tickerCancel: tickerCancel, tickerCancel: tickerCancel,
ipPool: NewPreallocatedIPs(),
logger: logger, logger: logger,
flowLogger: flowLogger,
} }
// Start cleanup routine // Start cleanup routine
@@ -173,27 +178,27 @@ func TestUDPTracker_Cleanup(t *testing.T) {
// Add some connections // Add some connections
connections := []struct { connections := []struct {
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
}{ }{
{ {
srcIP: net.ParseIP("192.168.1.2"), srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"), dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345, srcPort: 12345,
dstPort: 53, dstPort: 53,
}, },
{ {
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"), dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346, srcPort: 12346,
dstPort: 53, dstPort: 53,
}, },
} }
for _, conn := range connections { for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
} }
// Verify initial connections // Verify initial connections
@@ -215,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
} }
}) })
} }

View File

@@ -1,6 +1,8 @@
package forwarder package forwarder
import ( import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true return true
} }
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -4,7 +4,9 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
@@ -17,7 +19,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -29,6 +33,9 @@ const (
type Forwarder struct { type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
// ruleIdMap is used to store the rule ID for a given connection
ruleIdMap sync.Map
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
@@ -38,7 +45,7 @@ type Forwarder struct {
netstack bool netstack bool
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@@ -102,9 +109,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{ f := &Forwarder{
logger: logger, logger: logger,
flowLogger: flowLogger,
stack: s, stack: s,
endpoint: endpoint, endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger), udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,
@@ -164,3 +172,35 @@ func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
} }
return addr.AsSlice() return addr.AsSlice()
} }
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
key := buildKey(srcIP, dstIP, srcPort, dstPort)
f.ruleIdMap.LoadOrStore(key, ruleID)
}
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return value.([]byte), true
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
return value.([]byte), true
}
return nil, false
}
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
return
}
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
}
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
return conntrack.ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -3,14 +3,30 @@ package forwarder
import ( import (
"context" "context"
"net" "net"
"net/netip"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// handleICMP handles ICMP packets from the network stack // handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel() defer cancel()
@@ -18,70 +34,55 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root // TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil { if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
} }
defer func() { defer func() {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err) f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
} }
}() }()
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader()) fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response // For Echo Requests, send and handle response
switch icmpHdr.Type() { if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
case header.ICMPv4Echo: rxBytes := pkt.Size()
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id) txBytes := f.handleEchoResponse(icmpHdr, conn, id)
case header.ICMPv4EchoReply: f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
// dont process our own replies
return true
default:
} }
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true return true
} }
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
id, icmpHdr.Type(), icmpHdr.Code())
return true
}
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err) f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
return true return 0
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
n, _, err := conn.ReadFrom(response) n, _, err := conn.ReadFrom(response)
if err != nil { if err != nil {
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err) f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
} }
return true return 0
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -100,10 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
fullPacket = append(fullPacket, response[:n]...) fullPacket = append(fullPacket, response[:n]...)
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err) f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
return true
return 0
} }
f.logger.Trace("Forwarded ICMP echo reply for %v", id) f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
return true epID(id), icmpHdr.Type(), icmpHdr.Code())
return len(fullPacket)
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
var rxPackets, txPackets uint64
if rxBytes > 0 {
rxPackets = 1
}
if txBytes > 0 {
txPackets = 1
}
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
ICMPType: icmpType,
ICMPCode: icmpCode,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
} }

View File

@@ -5,24 +5,40 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"sync"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// handleTCP is called by the TCP forwarder for new connections. // handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID() id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err) f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
return return
} }
@@ -44,47 +60,105 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection %v", id) success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep) go f.proxyTCP(id, inConn, outConn, ep, flowID)
} }
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() {
if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
// Create context for managing the proxy goroutines
ctx, cancel := context.WithCancel(f.ctx) ctx, cancel := context.WithCancel(f.ctx)
defer cancel() defer cancel()
errChan := make(chan error, 2) go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: inConn close error: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: outConn close error: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesFromInToOut int64 // bytes from client to server (tx for client)
bytesFromOutToIn int64 // bytes from server to client (rx for client)
errInToOut error
errOutToIn error
)
go func() { go func() {
_, err := io.Copy(outConn, inConn) bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
errChan <- err cancel()
wg.Done()
}() }()
go func() { go func() {
_, err := io.Copy(inConn, outConn)
errChan <- err bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
cancel()
wg.Done()
}() }()
select { wg.Wait()
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id) if errInToOut != nil {
return if !isClosedError(errInToOut) {
case err := <-errChan: f.logger.Error("proxyTCP: copy error (in -> out): %v", errInToOut)
if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err)
}
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
return
} }
} }
if errOutToIn != nil {
if !isClosedError(errOutToIn) {
f.logger.Error("proxyTCP: copy error (out -> in): %v", errOutToIn)
}
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
}
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -5,10 +5,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -16,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -28,11 +31,13 @@ type udpPacketConn struct {
lastSeen atomic.Int64 lastSeen atomic.Int64
cancel context.CancelFunc cancel context.CancelFunc
ep tcpip.Endpoint ep tcpip.Endpoint
flowID uuid.UUID
} }
type udpForwarder struct { type udpForwarder struct {
sync.RWMutex sync.RWMutex
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
conns map[stack.TransportEndpointID]*udpPacketConn conns map[stack.TransportEndpointID]*udpPacketConn
bufPool sync.Pool bufPool sync.Pool
ctx context.Context ctx context.Context
@@ -44,10 +49,11 @@ type idleConn struct {
conn *udpPacketConn conn *udpPacketConn
} }
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger, logger: logger,
flowLogger: flowLogger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn), conns: make(map[stack.TransportEndpointID]*udpPacketConn),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns { for id, conn := range f.conns {
conn.cancel() conn.cancel()
if err := conn.conn.Close(); err != nil { if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
} }
if err := conn.outConn.Close(); err != nil { if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
conn.ep.Close() conn.ep.Close()
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
for _, idle := range idleConns { for _, idle := range idleConns {
idle.conn.cancel() idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil { if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err) f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
} }
if err := idle.conn.outConn.Close(); err != nil { if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
} }
idle.conn.ep.Close() idle.conn.ep.Close()
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id) delete(f.conns, idle.id)
f.Unlock() f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
} }
} }
} }
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
_, exists := f.udpForwarder.conns[id] _, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock() f.udpForwarder.RUnlock()
if exists { if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id) f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
return return
} }
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message // TODO: Send ICMP error message
return return
} }
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if epErr != nil { if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr) f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn: outConn, outConn: outConn,
cancel: connCancel, cancel: connCancel,
ep: ep, ep: ep,
flowID: flowID,
} }
pConn.updateLastSeen() pConn.updateLastSeen()
@@ -177,58 +194,114 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
pConn.cancel() pConn.cancel()
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id) success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
} }
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
defer func() {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
pConn.cancel() pConn.cancel()
if err := pConn.conn.Close(); err != nil { if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := pConn.outConn.Close(); err != nil { if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
ep.Close() ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var txBytes, rxBytes int64
var outboundErr, inboundErr error
// outbound->inbound: copy from pConn.conn to pConn.outConn
go func() {
defer wg.Done()
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
// inbound->outbound: copy from pConn.outConn to pConn.conn
go func() {
defer wg.Done()
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
wg.Wait()
if outboundErr != nil && !isClosedError(outboundErr) {
f.logger.Error("proxyUDP: copy error (outbound->inbound): %v", outboundErr)
}
if inboundErr != nil && !isClosedError(inboundErr) {
f.logger.Error("proxyUDP: copy error (inbound->outbound): %v", inboundErr)
}
var rxPackets, txPackets uint64
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
rxPackets = udpStats.PacketsSent.Value()
txPackets = udpStats.PacketsReceived.Value()
}
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
f.udpForwarder.Lock() f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
}()
errChan := make(chan error, 2) f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
go func() {
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
}()
go func() {
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
}()
select {
case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
return
case err := <-errChan:
if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
return // sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: srcIp,
DestIP: dstIp,
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
RxBytes: rxBytes,
TxBytes: txBytes,
RxPackets: rxPackets,
TxPackets: txPackets,
} }
if typ == nftypes.TypeStart {
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
fields.RuleID = ruleId
}
} else {
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
}
f.flowLogger.StoreEvent(fields)
} }
func (c *udpPacketConn) updateLastSeen() { func (c *udpPacketConn) updateLastSeen() {
@@ -240,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
return time.Since(lastSeen) return time.Since(lastSeen)
} }
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { // copy reads from src and writes to dst.
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
bufp := bufPool.Get().(*[]byte) bufp := bufPool.Get().(*[]byte)
defer bufPool.Put(bufp) defer bufPool.Put(bufp)
buffer := *bufp buffer := *bufp
var totalBytes int64 = 0
for { for {
if ctx.Err() != nil { if ctx.Err() != nil {
return ctx.Err() return totalBytes, ctx.Err()
} }
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil { if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err) return totalBytes, fmt.Errorf("set read deadline: %w", err)
} }
n, err := src.Read(buffer) n, err := src.Read(buffer)
@@ -259,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
if isTimeout(err) { if isTimeout(err) {
continue continue
} }
return fmt.Errorf("read from %s: %w", direction, err) return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
} }
_, err = dst.Write(buffer[:n]) nWritten, err := dst.Write(buffer[:n])
if err != nil { if err != nil {
return fmt.Errorf("write to %s: %w", direction, err) return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
} }
totalBytes += int64(nWritten)
c.updateLastSeen() c.updateLastSeen()
} }
} }

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -13,8 +14,13 @@ import (
type localIPManager struct { type localIPManager struct {
mu sync.RWMutex mu sync.RWMutex
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory) // fixed-size high array for upper byte of a IPv4 address
ipv4Bitmap [1 << 16]uint32 ipv4Bitmap [256]*ipv4LowBitmap
}
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
type ipv4LowBitmap struct {
bitmap [8192]uint32
} }
func newLocalIPManager() *localIPManager { func newLocalIPManager() *localIPManager {
@@ -26,39 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
if ipv4 == nil { if ipv4 == nil {
return return
} }
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) high := uint16(ipv4[0])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
m.ipv4Bitmap[high] |= 1 << (low % 32)
index := low / 32
bit := low % 32
if m.ipv4Bitmap[high] == nil {
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
} }
func (m *localIPManager) checkBitmapBit(ip net.IP) bool { m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
ipv4 := ip.To4()
if ipv4 == nil {
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
} }
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
if ipv4 := ip.To4(); ipv4 != nil { if ipv4 := ip.To4(); ipv4 != nil {
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) high := uint16(ipv4[0])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
if int(high) >= len(*newIPv4Bitmap) {
return fmt.Errorf("invalid IPv4 address: %s", ip) if bitmap[high] == nil {
bitmap[high] = &ipv4LowBitmap{}
} }
ipStr := ip.String()
index := low / 32
bit := low % 32
bitmap[high].bitmap[index] |= 1 << bit
ipStr := ipv4.String()
if _, exists := ipv4Set[ipStr]; !exists { if _, exists := ipv4Set[ipStr]; !exists {
ipv4Set[ipStr] = struct{}{} ipv4Set[ipStr] = struct{}{}
*ipv4Addresses = append(*ipv4Addresses, ipStr) *ipv4Addresses = append(*ipv4Addresses, ipStr)
newIPv4Bitmap[high] |= 1 << (low % 32)
} }
} }
}
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
high := uint16(ip[0])
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
if m.ipv4Bitmap[high] == nil {
return false
}
index := low / 32
bit := low % 32
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
}
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
return nil return nil
} }
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
addrs, err := iface.Addrs() addrs, err := iface.Addrs()
if err != nil { if err != nil {
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
@@ -76,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
continue continue
} }
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil { if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
log.Debugf("process IP failed: %v", err) log.Debugf("process IP failed: %v", err)
} }
} }
@@ -89,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
} }
}() }()
var newIPv4Bitmap [1 << 16]uint32 var newIPv4Bitmap [256]*ipv4LowBitmap
ipv4Set := make(map[string]struct{}) ipv4Set := make(map[string]struct{})
var ipv4Addresses []string var ipv4Addresses []string
// 127.0.0.0/8 // 127.0.0.0/8
high := uint16(127) << 8 newIPv4Bitmap[127] = &ipv4LowBitmap{}
for i := uint16(0); i < 256; i++ { for i := 0; i < 8192; i++ {
newIPv4Bitmap[high|i] = 0xffffffff newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
} }
if iface != nil { if iface != nil {
@@ -122,13 +148,13 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil return nil
} }
func (m *localIPManager) IsLocalIP(ip net.IP) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil { return m.checkBitmapBit(ip.AsSlice())
return m.checkBitmapBit(ipv4)
}
return false
} }

View File

@@ -2,6 +2,7 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupAddr wgaddr.Address setupAddr wgaddr.Address
testIP net.IP testIP netip.Addr
expected bool expected bool
}{ }{
{ {
@@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
}, },
{ {
@@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
}, },
{ {
@@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
}, },
{ {
@@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
}, },
{ {
@@ -73,7 +74,19 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false,
},
{
name: "Local IP doesn't match - addresses 32 apart",
setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32),
},
},
testIP: netip.MustParseAddr("192.168.1.33"),
expected: false, expected: false,
}, },
{ {
@@ -85,7 +98,7 @@ func TestLocalIPManager(t *testing.T) {
Mask: net.CIDRMask(64, 128), Mask: net.CIDRMask(64, 128),
}, },
}, },
testIP: net.ParseIP("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
}, },
} }
@@ -174,7 +187,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests)) t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) { t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip)) result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip) require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
}) })
} }
@@ -191,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
} }
// Setup bitmap version // Setup bitmap
bitmapManager := &localIPManager{ bitmapManager := newLocalIPManager()
ipv4Bitmap: [1 << 16]uint32{},
}
for _, ip := range interfaces[:8] { // Add half of IPs for _, ip := range interfaces[:8] { // Add half of IPs
bitmapManager.setBitmapBit(ip) bitmapManager.setBitmapBit(ip)
} }
@@ -247,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
// Create two managers - one checks WG IP first, other checks it last // Create two managers - one checks WG IP first, other checks it last
b.Run("WG_First", func(b *testing.B) { b.Run("WG_First", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} bm := newLocalIPManager()
bm.setBitmapBit(wgIP) bm.setBitmapBit(wgIP)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@@ -256,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
}) })
b.Run("WG_Last", func(b *testing.B) { b.Run("WG_Last", func(b *testing.B) {
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} bm := newLocalIPManager()
// Fill with other IPs first // Fill with other IPs first
for i := 0; i < 15; i++ { for i := 0; i < 15; i++ {
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))

View File

@@ -1,4 +1,4 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking // Package log provides a high-performance, non-blocking logger for userspace networking
package log package log
import ( import (
@@ -13,13 +13,12 @@ import (
) )
const ( const (
maxBatchSize = 1024 * 16 // 16KB max batch size maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2 // 2KB per message maxMessageSize = 1024 * 2
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
) )
// Level represents log severity
type Level uint32 type Level uint32
const ( const (
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC", LevelTrace: "TRAC",
} }
type logMessage struct {
level Level
format string
args []any
}
// Logger is a high-performance, non-blocking logger // Logger is a high-performance, non-blocking logger
type Logger struct { type Logger struct {
output io.Writer output io.Writer
level atomic.Uint32 level atomic.Uint32
buffer *ringBuffer msgChannel chan logMessage
shutdown chan struct{} shutdown chan struct{}
closeOnce sync.Once closeOnce sync.Once
wg sync.WaitGroup wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool bufPool sync.Pool
} }
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger { func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{ l := &Logger{
output: logrusLogger.Out, output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize), msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() interface{} { New: func() any {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize) b := make([]byte, 0, maxMessageSize)
return &b return &b
}, },
}, },
} }
logrusLevel := logrusLogger.GetLevel() logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel)) l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)] level := levelStrings[Level(logrusLevel)]
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
return l return l
} }
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) { func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level)) l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { func (l *Logger) log(level Level, format string, args ...any) {
*buf = (*buf)[:0] select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
// Timestamp default:
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") }
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
} }
*buf = append(*buf, '\n') // Error logs a message at error level
} func (l *Logger) Error(format string, args ...any) {
func (l *Logger) log(level Level, format string, args ...interface{}) {
bufp := l.bufPool.Get().(*[]byte)
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...) l.log(LevelError, format, args...)
} }
} }
func (l *Logger) Warn(format string, args ...interface{}) { // Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) { if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...) l.log(LevelWarn, format, args...)
} }
} }
func (l *Logger) Info(format string, args ...interface{}) { // Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) { if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...) l.log(LevelInfo, format, args...)
} }
} }
func (l *Logger) Debug(format string, args ...interface{}) { // Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...) l.log(LevelDebug, format, args...)
} }
} }
func (l *Logger) Trace(format string, args ...interface{}) { // Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...) l.log(LevelTrace, format, args...)
} }
} }
// worker periodically flushes the buffer func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() { func (l *Logger) worker() {
defer l.wg.Done() defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval) ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop() defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize) buffer := make([]byte, 0, maxBatchSize)
for { for {
select { select {
case <-l.shutdown: case <-l.shutdown:
l.handleShutdown(&buffer)
return return
case <-ticker.C: case <-ticker.C:
// Read accumulated messages l.flushBuffer(&buffer)
n, _ := l.buffer.Read(buf[:cap(buf)]) case msg := <-l.msgChannel:
if n == 0 { l.processMessage(msg, &buffer)
continue l.processBatch(&buffer)
}
// Write batch
_, _ = l.output.Write(buf[:n])
} }
} }
} }

View File

@@ -0,0 +1,121 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -1,85 +0,0 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"net"
"net/netip" "net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -12,14 +11,14 @@ import (
// PeerRule to handle management of rules // PeerRule to handle management of rules
type PeerRule struct { type PeerRule struct {
id string id string
ip net.IP mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
sPort *firewall.Port sPort *firewall.Port
dPort *firewall.Port dPort *firewall.Port
drop bool drop bool
comment string
udpHook func([]byte) bool udpHook func([]byte) bool
} }
@@ -31,8 +30,10 @@ func (r *PeerRule) ID() string {
type RouteRule struct { type RouteRule struct {
id string id string
mgmtId []byte
sources []netip.Prefix sources []netip.Prefix
destination netip.Prefix dstSet firewall.Set
destinations []netip.Prefix
proto firewall.Protocol proto firewall.Protocol
srcPort *firewall.Port srcPort *firewall.Port
dstPort *firewall.Port dstPort *firewall.Port

View File

@@ -2,7 +2,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net/netip"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
} }
type PacketTrace struct { type PacketTrace struct {
SourceIP net.IP SourceIP netip.Addr
DestinationIP net.IP DestinationIP netip.Addr
Protocol string Protocol string
SourcePort uint16 SourcePort uint16
DestinationPort uint16 DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
} }
type PacketBuilder struct { type PacketBuilder struct {
SrcIP net.IP SrcIP netip.Addr
DstIP net.IP DstIP netip.Addr
Protocol fw.Protocol Protocol fw.Protocol
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP, SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP, DstIP: p.DstIP.AsSlice(),
} }
} }
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP) return m.traceInbound(packetData, trace, d, srcIP, dstIP)
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
if m.localipmanager.IsLocalIP(dstIP) {
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace return trace
} }
}
if !m.handleRouting(trace) { if !m.handleRouting(trace) {
return trace return trace
} }
if m.nativeRouter { if m.nativeRouter.Load() {
return m.handleNativeRouter(trace) return m.handleNativeRouter(trace)
} }
return m.handleRouteACLs(trace, d, srcIP, dstIP) return m.handleRouteACLs(trace, d, srcIP, dstIP)
} }
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP) allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found" msg := "No existing connection found"
if allowed { if allowed {
msg = m.buildConntrackStateMessage(d) msg = m.buildConntrackStateMessage(d)
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg return msg
} }
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.localForwarding { trace.AddResult(StageRouting, "Packet destined for local delivery", true)
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true return true
} }
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StagePeerACL, msg, true)
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
msg := "Allowed by peer ACL rules"
if blocked {
msg = "Blocked by peer ACL rules"
}
trace.AddResult(StagePeerACL, msg, !blocked)
// Handle netstack mode
if m.netstack { if m.netstack {
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
} }
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) // In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return true return true
} }
func (m *Manager) handleRouting(trace *PacketTrace) bool { func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled { if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false) trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false return false
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace return trace
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto := getProtocolFromPacket(d) proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
msg := "Allowed by route ACLs" strId := string(id)
if id == nil {
strId = "<no id>"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
if !allowed { if !allowed {
msg = "Blocked by route ACLs" msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
} }
trace.AddResult(StageRouteACL, msg, allowed) trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil { if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
} }
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state // will create or update the connection state
dropped := m.processOutgoingHooks(packetData) dropped := m.processOutgoingHooks(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else { } else {

View File

@@ -0,0 +1,440 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
"192.168.17.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false, stateful: false,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, _, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern // Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0] ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, _, err := m.AddPeerFiltering(
nil,
ip,
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}}, &fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}}, &fw.Port{Values: []uint16{80}},
fw.ActionAccept, "", "explicit return") fw.ActionAccept,
"",
)
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true, stateful: true,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, _, err := m.AddPeerFiltering(
fw.ActionDrop, "", "default drop") nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@@ -158,7 +169,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -203,7 +214,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Test packet // Test packet
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut) manager.processOutgoingHooks(testOut, 0)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn) manager.dropFilter(testIn, 0)
} }
}) })
} }
@@ -251,7 +262,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -450,7 +461,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -577,7 +588,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
} }
@@ -668,7 +676,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
} }
@@ -787,7 +792,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Pre-generate test packets // Pre-generate test packets
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
}) })
@@ -875,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}) })
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
}) })
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering( _, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
r.sources,
r.dest,
r.proto,
nil,
r.port,
fw.ActionAccept,
)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tc := range cases { for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
} }
} }

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/management/domain"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
@@ -34,7 +35,7 @@ func TestPeerACLFiltering(t *testing.T) {
}, },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, manager) require.NotNil(t, manager)
@@ -188,24 +189,321 @@ func TestPeerACLFiltering(t *testing.T) {
ruleAction: fw.ActionAccept, ruleAction: fw.ActionAccept,
shouldBeBlocked: true, shouldBeBlocked: true,
}, },
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Allow TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Block TCP traffic outside port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: true,
},
{
name: "Edge Case - Port at Range Boundary",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8100,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "UDP Port Range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 5060,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
{
name: "Allow multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionAccept,
shouldBeBlocked: false,
},
// New drop test cases
{
name: "Drop TCP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop UDP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolUDP,
ruleDstPort: &fw.Port{Values: []uint16{53}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop ICMP traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolICMP,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolICMP,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop all traffic from WG peer",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolALL,
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop traffic from multiple source ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop multiple destination ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Drop TCP traffic within port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 8080,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: false,
},
{
name: "Drop TCP traffic with source port range",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 32100,
dstPort: 80,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
{
name: "Mixed rule - drop specific port but allow other ports",
srcIP: "100.10.0.1",
dstIP: "100.10.0.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
ruleIP: "100.10.0.1",
ruleProto: fw.ProtocolTCP,
ruleDstPort: &fw.Port{Values: []uint16{443}},
ruleAction: fw.ActionDrop,
shouldBeBlocked: true,
},
} }
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist") require.True(t, isDropped, "Packet should be dropped when no rules exist")
}) })
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.ruleAction == fw.ActionDrop {
// add general accept rule to test drop rule
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
"",
)
require.NoError(t, err)
require.NotEmpty(t, rules)
t.Cleanup(func() {
for _, rule := range rules {
require.NoError(t, manager.DeletePeerRule(rule))
}
})
}
rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP(tc.ruleIP), net.ParseIP(tc.ruleIP),
tc.ruleProto, tc.ruleProto,
tc.ruleSrcPort, tc.ruleSrcPort,
tc.ruleDstPort, tc.ruleDstPort,
tc.ruleAction, tc.ruleAction,
"", "",
tc.name,
) )
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, rules) require.NotEmpty(t, rules)
@@ -217,7 +515,7 @@ func TestPeerACLFiltering(t *testing.T) {
}) })
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped) require.Equal(t, tc.shouldBeBlocked, isDropped)
}) })
} }
@@ -302,12 +600,12 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}, },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter) require.False(tb, manager.nativeRouter.Load())
tb.Cleanup(func() { tb.Cleanup(func() {
require.NoError(tb, manager.Close(nil)) require.NoError(tb, manager.Close(nil))
@@ -321,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) {
type rule struct { type rule struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -347,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -363,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -379,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -395,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 53, dstPort: 53,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{53}}, dstPort: &fw.Port{Values: []uint16{53}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -409,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@@ -424,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -440,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -456,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -472,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -488,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -507,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -521,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) {
proto: fw.ProtocolICMP, proto: fw.ProtocolICMP,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
@@ -536,33 +834,13 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
shouldPass: true, shouldPass: true,
}, },
{
name: "Multiple source networks with mismatched protocol",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
// Should not match TCP rule
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: netip.MustParsePrefix("192.168.1.0/24"),
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{ {
name: "Allow multiple destination ports", name: "Allow multiple destination ports",
srcIP: "100.10.0.1", srcIP: "100.10.0.1",
@@ -572,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -588,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -604,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
srcPort: &fw.Port{Values: []uint16{12345}}, srcPort: &fw.Port{Values: []uint16{12345}},
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
@@ -621,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -640,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 7999, dstPort: 7999,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -659,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -678,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
srcPort: &fw.Port{ srcPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -700,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8100, dstPort: 8100,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -719,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 5060, dstPort: 5060,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP, proto: fw.ProtocolUDP,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -738,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 8080, dstPort: 8080,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
dstPort: &fw.Port{ dstPort: &fw.Port{
IsRange: true, IsRange: true,
@@ -757,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 443, dstPort: 443,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -773,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) {
dstPort: 80, dstPort: 80,
rule: rule{ rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
@@ -791,18 +1069,160 @@ func TestRouteACLFiltering(t *testing.T) {
netip.MustParsePrefix("100.10.0.0/16"), netip.MustParsePrefix("100.10.0.0/16"),
netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("172.16.0.0/16"),
}, },
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
}, },
shouldPass: false, shouldPass: false,
}, },
{
name: "Drop empty destination set",
srcIP: "172.16.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
},
dest: fw.Network{Set: fw.Set{}},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "Accept TCP traffic outside drop port range",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 7999,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
action: fw.ActionDrop,
},
shouldPass: true,
},
{
name: "Allow TCP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 443,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "Allow UDP traffic without port specification",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 53,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: true,
},
{
name: "TCP packet doesn't match UDP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolTCP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "UDP packet doesn't match TCP filter with same port",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolUDP,
srcPort: 12345,
dstPort: 80,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match TCP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP,
action: fw.ActionAccept,
},
shouldPass: false,
},
{
name: "ICMP packet doesn't match UDP filter",
srcIP: "100.10.0.1",
dstIP: "192.168.1.100",
proto: fw.ProtocolICMP,
rule: rule{
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolUDP,
action: fw.ActionAccept,
},
shouldPass: false,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.rule.action == fw.ActionDrop {
// add general accept rule to test drop rule
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
fw.ProtocolALL,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
t.Cleanup(func() {
require.NoError(t, manager.DeleteRouteRule(rule))
})
}
rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources, tc.rule.sources,
tc.rule.dest, tc.rule.dest,
tc.rule.proto, tc.rule.proto,
@@ -817,12 +1237,12 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule)) require.NoError(t, manager.DeleteRouteRule(rule))
}) })
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)
}) })
} }
@@ -835,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) {
name string name string
rules []struct { rules []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -856,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Drop rules take precedence over accept", name: "Drop rules take precedence over accept",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -865,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept rule added first // Accept rule added first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80, 443}}, dstPort: &fw.Port{Values: []uint16{80, 443}},
action: fw.ActionAccept, action: fw.ActionAccept,
@@ -873,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop rule added second but should be evaluated first // Drop rule added second but should be evaluated first
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -911,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) {
name: "Multiple drop rules take precedence", name: "Multiple drop rules take precedence",
rules: []struct { rules: []struct {
sources []netip.Prefix sources []netip.Prefix
dest netip.Prefix dest fw.Network
proto fw.Protocol proto fw.Protocol
srcPort *fw.Port srcPort *fw.Port
dstPort *fw.Port dstPort *fw.Port
@@ -920,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Accept all // Accept all
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
dest: netip.MustParsePrefix("0.0.0.0/0"), dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
proto: fw.ProtocolALL, proto: fw.ProtocolALL,
action: fw.ActionAccept, action: fw.ActionAccept,
}, },
{ {
// Drop specific port // Drop specific port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{443}}, dstPort: &fw.Port{Values: []uint16{443}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -935,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) {
{ {
// Drop different port // Drop different port
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
dest: netip.MustParsePrefix("192.168.1.0/24"), dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
proto: fw.ProtocolTCP, proto: fw.ProtocolTCP,
dstPort: &fw.Port{Values: []uint16{80}}, dstPort: &fw.Port{Values: []uint16{80}},
action: fw.ActionDrop, action: fw.ActionDrop,
@@ -985,6 +1405,7 @@ func TestRouteACLOrder(t *testing.T) {
var rules []fw.Rule var rules []fw.Rule
for _, r := range tc.rules { for _, r := range tc.rules {
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil,
r.sources, r.sources,
r.dest, r.dest,
r.proto, r.proto,
@@ -1004,12 +1425,62 @@ func TestRouteACLOrder(t *testing.T) {
}) })
for i, p := range tc.packets { for i, p := range tc.packets {
srcIP := net.ParseIP(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := net.ParseIP(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
} }
}) })
} }
} }
func TestRouteACLSet(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
// Add rule that uses the set (initially empty)
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err)
// Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
}

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -18,9 +19,12 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/management/domain"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
ip := net.ParseIP("192.168.1.1") ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok { if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok { if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string name string
in bool in bool
expDir fw.RuleDirection expDir fw.RuleDirection
ip net.IP ip netip.Addr
dPort uint16 dPort uint16
hook func([]byte) bool hook func([]byte) bool
expectedID string expectedID string
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook", name: "Test Outgoing UDP Packet Hook",
in: false, in: false,
expDir: fw.RuleDirectionOUT, expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1), ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000, dPort: 8000,
hook: func([]byte) bool { return true }, hook: func([]byte) bool { return true },
}, },
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook", name: "Test Incoming UDP Packet Hook",
in: true, in: true,
expDir: fw.RuleDirectionIN, expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback, ip: netip.MustParseAddr("::1"),
dPort: 9000, dPort: 9000,
hook: func([]byte) bool { return false }, hook: func([]byte) bool { return false },
}, },
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
for _, rule := range manager.incomingRules[tt.ip.String()] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip.String()] { for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} }
if !tt.ip.Equal(addedRule.ip) { if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -246,9 +248,8 @@ func TestManagerReset(t *testing.T) {
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) _, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment) _, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes()) { if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@@ -347,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface, false) manager, err := Create(iface, false, flowLogger)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -401,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}() }()
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false hookCalled := false
hookID := manager.AddUDPPacketHook( hookID := manager.AddUDPPacketHook(
false, false,
net.ParseIP("100.10.0.100"), netip.MustParseAddr("100.10.0.100"),
53, 53,
func([]byte) bool { func([]byte) bool {
hookCalled = true hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes()) result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes()) result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result) require.False(t, result)
} }
@@ -479,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -534,8 +534,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}() }()
// Set up packet parameters // Set up packet parameters
srcIP := net.ParseIP("100.10.0.1") srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100") dstIP := netip.MustParseAddr("100.10.0.100")
srcPort := uint16(51334) srcPort := uint16(51334)
dstPort := uint16(53) dstPort := uint16(53)
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{ outboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: srcIP, SrcIP: srcIP.AsSlice(),
DstIP: dstIP, DstIP: dstIP.AsSlice(),
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
outboundUDP := &layers.UDP{ outboundUDP := &layers.UDP{
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes()) drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{ inboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: dstIP, // Original destination is now source SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP, // Original source is now destination DstIP: srcIP.AsSlice(), // Original source is now destination
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
inboundUDP := &layers.UDP{ inboundUDP := &layers.UDP{
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes()) drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes()) drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@@ -707,8 +707,208 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes()) drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }
} }
func TestUpdateSetMerge(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
// Update the set with initial prefixes
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Test initial prefixes work
srcIP := netip.MustParseAddr("100.10.0.1")
dstIP1 := netip.MustParseAddr("10.0.0.100")
dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
newPrefixes := []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.1.0.0/24"),
}
err = manager.UpdateSet(set, newPrefixes)
require.NoError(t, err)
// Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
// Check that new prefixes are included
dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
// Verify the rule has all prefixes
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
"Rule should have all prefixes merged")
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
}
func TestUpdateSetDeduplication(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
set := fw.NewDomainSet(domain.List{"example.org"})
rule, err := manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
fw.Network{Set: set},
fw.ProtocolTCP,
nil,
nil,
fw.ActionAccept,
)
require.NoError(t, err)
require.NotNil(t, rule)
initialPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
}
err = manager.UpdateSet(set, initialPrefixes)
require.NoError(t, err)
// Check the internal state for deduplication
manager.mutex.RLock()
foundRule := false
for _, r := range manager.routeRules {
if r.id == rule.ID() {
foundRule = true
// Should have deduplicated to 2 prefixes
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
// Check the prefixes are correct
expectedPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("192.168.1.0/24"),
}
for i, prefix := range expectedPrefixes {
require.True(t, r.destinations[i] == prefix,
"Prefix should match expected value")
}
}
}
manager.mutex.RUnlock()
require.True(t, foundRule, "Rule should be found")
// Test with overlapping prefixes of different sizes
overlappingPrefixes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/16"), // More general
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
netip.MustParsePrefix("192.168.0.0/16"), // More general
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
}
err = manager.UpdateSet(set, overlappingPrefixes)
require.NoError(t, err)
// Check that all prefixes are included (no deduplication of overlapping prefixes)
manager.mutex.RLock()
for _, r := range manager.routeRules {
if r.id == rule.ID() {
// Should have all 4 prefixes (2 original + 2 new more general ones)
require.Len(t, r.destinations, 4,
"Overlapping prefixes should not be deduplicated")
// Verify they're sorted correctly (more specific prefixes should come first)
prefixes := make([]string, 0, len(r.destinations))
for _, p := range r.destinations {
prefixes = append(prefixes, p.String())
}
// Check sorted order
require.Equal(t, []string{
"10.0.0.0/16",
"10.0.0.0/24",
"192.168.0.0/16",
"192.168.1.0/24",
}, prefixes, "Prefixes should be sorted")
}
}
manager.mutex.RUnlock()
// Test functionality with all prefixes
testCases := []struct {
dstIP netip.Addr
expected bool
desc string
}{
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
}
srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}

View File

@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
// NewUDPMuxDefault creates an implementation of UDPMux // NewUDPMuxDefault creates an implementation of UDPMux
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = getLogger()
} }
mux := &UDPMuxDefault{ mux := &UDPMuxDefault{
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
buf: make([]byte, size), buf: make([]byte, size),
} }
} }
func getLogger() logging.LeveledLogger {
fac := logging.NewDefaultLoggerFactory()
//fac.Writer = log.StandardLogger().Writer()
return fac.NewLogger("ice")
}

View File

@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault { func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
if params.Logger == nil { if params.Logger == nil {
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice") params.Logger = getLogger()
} }
if params.XORMappedAddrCacheTTL == 0 { if params.XORMappedAddrCacheTTL == 0 {
params.XORMappedAddrCacheTTL = time.Second * 25 params.XORMappedAddrCacheTTL = time.Second * 25

View File

@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
func getFwmark() int { func getFwmark() int {
if nbnet.AdvancedRouting() { if nbnet.AdvancedRouting() {
return nbnet.NetbirdFwmark return nbnet.ControlPlaneMark
} }
return 0 return 0
} }

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -10,16 +11,16 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host // DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not. // Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument. // Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) { if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@@ -6,6 +6,7 @@ package mocks
import ( import (
net "net" net "net"
"net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
} }
// AddUDPPacketHook mocks base method. // AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
} }
// DropIncoming mocks base method. // DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0) ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
} }
// DropOutgoing mocks base method. // DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0) ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
} }
// RemovePacketHook mocks base method. // RemovePacketHook mocks base method.

View File

@@ -24,6 +24,8 @@
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run" !define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -49,6 +51,10 @@ ShowInstDetails Show
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!include "nsDialogs.nsh"
!define MUI_ICON "${ICON}" !define MUI_ICON "${ICON}"
!define MUI_UNICON "${ICON}" !define MUI_UNICON "${ICON}"
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}" !define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
@@ -58,9 +64,6 @@ ShowInstDetails Show
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink" !define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
###################################################################### ######################################################################
!include "MUI2.nsh"
!include LogicLib.nsh
!define MUI_ABORTWARNING !define MUI_ABORTWARNING
!define MUI_UNABORTWARNING !define MUI_UNABORTWARNING
@@ -70,13 +73,16 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
!insertmacro MUI_UNPAGE_WELCOME
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
!insertmacro MUI_UNPAGE_CONFIRM !insertmacro MUI_UNPAGE_CONFIRM
!insertmacro MUI_UNPAGE_INSTFILES !insertmacro MUI_UNPAGE_INSTFILES
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
Var AutostartCheckbox Var AutostartCheckbox
Var AutostartEnabled Var AutostartEnabled
; Variables for uninstall data deletion option
Var DeleteDataCheckbox
Var DeleteDataEnabled
###################################################################### ######################################################################
; Function to create the autostart options page ; Function to create the autostart options page
@@ -104,8 +114,8 @@ Function AutostartPage
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts" ${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked ${NSD_Check} $AutostartCheckbox
StrCpy $AutostartEnabled "1" ; Default to enabled StrCpy $AutostartEnabled "1"
nsDialogs::Show nsDialogs::Show
FunctionEnd FunctionEnd
@@ -115,6 +125,30 @@ Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled ${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd FunctionEnd
; Function to create the uninstall data deletion page
Function un.DeleteDataPage
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
Pop $DeleteDataCheckbox
${NSD_Uncheck} $DeleteDataCheckbox
StrCpy $DeleteDataEnabled "0"
nsDialogs::Show
FunctionEnd
; Function to handle leaving the data deletion page
Function un.DeleteDataPageLeave
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -225,31 +259,58 @@ SectionEnd
Section Uninstall Section Uninstall
${INSTALL_TYPE} ${INSTALL_TYPE}
DetailPrint "Stopping Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
DetailPrint "Uninstalling Netbird service..."
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."
${If} $DeleteDataEnabled == "1"
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
ClearErrors
RMDir /r "${NETBIRD_DATA_DIR}"
IfErrors 0 +2 ; If no errors, jump over the message
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
DetailPrint "Netbird data directory removal complete."
${Else}
DetailPrint "User did not opt to delete Netbird data."
${EndIf}
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
DetailPrint "Waiting for service handle to be released..."
Sleep 3000 Sleep 3000
DetailPrint "Deleting application files..."
Delete "$INSTDIR\${UI_APP_EXE}" Delete "$INSTDIR\${UI_APP_EXE}"
Delete "$INSTDIR\${MAIN_APP_EXE}" Delete "$INSTDIR\${MAIN_APP_EXE}"
Delete "$INSTDIR\wintun.dll" Delete "$INSTDIR\wintun.dll"
Delete "$INSTDIR\opengl32.dll" Delete "$INSTDIR\opengl32.dll"
DetailPrint "Removing application directory..."
RmDir /r "$INSTDIR" RmDir /r "$INSTDIR"
DetailPrint "Removing shortcuts..."
SetShellVarContext all SetShellVarContext all
Delete "$DESKTOP\${APP_NAME}.lnk" Delete "$DESKTOP\${APP_NAME}.lnk"
Delete "$SMPROGRAMS\${APP_NAME}.lnk" Delete "$SMPROGRAMS\${APP_NAME}.lnk"
DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM
EnVar::DeleteValue "path" "$INSTDIR" EnVar::DeleteValue "path" "$INSTDIR"
DetailPrint "Uninstallation finished."
SectionEnd SectionEnd

View File

@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
func GenerateRouteRuleKey( func GenerateRouteRuleKey(
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination manager.Network,
proto manager.Protocol, proto manager.Protocol,
sPort *manager.Port, sPort *manager.Port,
dPort *manager.Port, dPort *manager.Port,

View File

@@ -18,6 +18,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/management/domain"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
@@ -25,7 +26,12 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
// Manager is a ACL rules manager // Manager is a ACL rules manager
type Manager interface { type Manager interface {
ApplyFiltering(networkMap *mgmProto.NetworkMap) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
}
type protoMatch struct {
ips map[string]int
policyID []byte
} }
// DefaultManager uses firewall manager to handle // DefaultManager uses firewall manager to handle
@@ -48,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy. // ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
// //
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains. // If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
@@ -77,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
log.Errorf("failed to set legacy management flag: %v", err) log.Errorf("failed to set legacy management flag: %v", err)
} }
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil { if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err) log.Errorf("Failed to apply route ACLs: %v", err)
} }
@@ -171,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
d.peerRulesPairs = newRulePairs d.peerRulesPairs = newRulePairs
} }
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error { func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
newRouteRules := make(map[id.RuleID]struct{}, len(rules)) newRouteRules := make(map[id.RuleID]struct{}, len(rules))
var merr *multierror.Error var merr *multierror.Error
// Apply new rules - firewall manager will return existing rule ID if already present // Apply new rules - firewall manager will return existing rule ID if already present
for _, rule := range rules { for _, rule := range rules {
id, err := d.applyRouteACL(rule) id, err := d.applyRouteACL(rule, dynamicResolver)
if err != nil { if err != nil {
if errors.Is(err, ErrSourceRangesEmpty) { if errors.Is(err, ErrSourceRangesEmpty) {
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err) log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
} else { } else {
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err)) merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
} }
@@ -203,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)
} }
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) { func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
if len(rule.SourceRanges) == 0 { if len(rule.SourceRanges) == 0 {
return "", ErrSourceRangesEmpty return "", ErrSourceRangesEmpty
} }
@@ -217,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
sources = append(sources, source) sources = append(sources, source)
} }
var destination netip.Prefix destination, err := determineDestination(rule, dynamicResolver, sources)
if rule.IsDynamic {
destination = getDefault(sources[0])
} else {
var err error
destination, err = netip.ParsePrefix(rule.Destination)
if err != nil { if err != nil {
return "", fmt.Errorf("parse destination: %w", err) return "", fmt.Errorf("determine destination: %w", err)
}
} }
protocol, err := convertToFirewallProtocol(rule.Protocol) protocol, err := convertToFirewallProtocol(rule.Protocol)
@@ -240,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
dPorts := convertPortInfo(rule.PortInfo) dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
if err != nil { if err != nil {
return "", fmt.Errorf("add route rule: %w", err) return "", fmt.Errorf("add route rule: %w", err)
} }
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
} }
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil return ruleID, rulesPair, nil
} }
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
var rules []firewall.Rule var rules []firewall.Rule
switch r.Direction { switch r.Direction {
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. // TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already // We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
} }
func (d *DefaultManager) addInRules( func (d *DefaultManager) addInRules(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment) rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return nil, nil return nil, nil
} }
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment) rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
direction int, direction int,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
comment string,
) id.RuleID { ) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil { if port != nil {
idStr += port.String() idStr += port.String()
} }
@@ -389,10 +388,8 @@ func (d *DefaultManager) squashAcceptRules(
} }
} }
type protoMatch map[mgmProto.RuleProtocol]map[string]int in := map[mgmProto.RuleProtocol]*protoMatch{}
out := map[mgmProto.RuleProtocol]*protoMatch{}
in := protoMatch{}
out := protoMatch{}
// trace which type of protocols was squashed // trace which type of protocols was squashed
squashedRules := []*mgmProto.FirewallRule{} squashedRules := []*mgmProto.FirewallRule{}
@@ -405,14 +402,18 @@ func (d *DefaultManager) squashAcceptRules(
// 2. Any of rule contains Port. // 2. Any of rule contains Port.
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
if drop { if drop {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = map[string]int{} protocols[r.Protocol] = &protoMatch{
ips: map[string]int{},
// store the first encountered PolicyID for this protocol
policyID: r.PolicyID,
}
} }
// special case, when we receive this all network IP address // special case, when we receive this all network IP address
@@ -424,7 +425,7 @@ func (d *DefaultManager) squashAcceptRules(
return return
} }
ipset := protocols[r.Protocol] ipset := protocols[r.Protocol].ips
if _, ok := ipset[r.PeerIP]; ok { if _, ok := ipset[r.PeerIP]; ok {
return return
@@ -450,9 +451,10 @@ func (d *DefaultManager) squashAcceptRules(
mgmProto.RuleProtocol_UDP, mgmProto.RuleProtocol_UDP,
} }
squash := func(matches protoMatch, direction mgmProto.RuleDirection) { squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
for _, protocol := range protocolOrders { for _, protocol := range protocolOrders {
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 { match, ok := matches[protocol]
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
// don't squash if : // don't squash if :
// 1. Rules not cover all peers in the network // 1. Rules not cover all peers in the network
// 2. Rules cover only one peer in the network. // 2. Rules cover only one peer in the network.
@@ -465,6 +467,7 @@ func (d *DefaultManager) squashAcceptRules(
Direction: direction, Direction: direction,
Action: mgmProto.RuleAction_ACCEPT, Action: mgmProto.RuleAction_ACCEPT,
Protocol: protocol, Protocol: protocol,
PolicyID: match.policyID,
}) })
squashedProtocols[protocol] = struct{}{} squashedProtocols[protocol] = struct{}{}
@@ -493,9 +496,9 @@ func (d *DefaultManager) squashAcceptRules(
// if we also have other not squashed rules. // if we also have other not squashed rules.
for i, r := range networkMap.FirewallRules { for i, r := range networkMap.FirewallRules {
if _, ok := squashedProtocols[r.Protocol]; ok { if _, ok := squashedProtocols[r.Protocol]; ok {
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i { if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue continue
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i { } else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
continue continue
} }
} }
@@ -572,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
return nil return nil
} }
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
var destination firewall.Network
if rule.IsDynamic {
if dynamicResolver {
if len(rule.Domains) > 0 {
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
} else {
// isDynamic is set but no domains = outdated management server
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
destination.Prefix = getDefault(sources[0])
}
} else {
// client resolves DNS, we (router) don't know the destination
destination.Prefix = getDefault(sources[0])
}
return destination, nil
}
prefix, err := netip.ParsePrefix(rule.Destination)
if err != nil {
return destination, fmt.Errorf("parse destination: %w", err)
}
destination.Prefix = prefix
return destination, nil
}
func getDefault(prefix netip.Prefix) netip.Prefix { func getDefault(prefix netip.Prefix) netip.Prefix {
if prefix.Addr().Is6() { if prefix.Addr().Is6() {
return netip.PrefixFrom(netip.IPv6Unspecified(), 0) return netip.PrefixFrom(netip.IPv6Unspecified(), 0)

View File

@@ -10,9 +10,12 @@ import (
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
@@ -52,7 +55,7 @@ func TestDefaultManager(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
@@ -63,7 +66,7 @@ func TestDefaultManager(t *testing.T) {
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
t.Run("apply firewall rules", func(t *testing.T) { t.Run("apply firewall rules", func(t *testing.T) {
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 2 {
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs) t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
@@ -89,7 +92,7 @@ func TestDefaultManager(t *testing.T) {
}, },
) )
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
// we should have one old and one new rule in the existed rules // we should have one old and one new rule in the existed rules
if len(acl.peerRulesPairs) != 2 { if len(acl.peerRulesPairs) != 2 {
@@ -113,13 +116,13 @@ func TestDefaultManager(t *testing.T) {
networkMap.FirewallRules = networkMap.FirewallRules[:0] networkMap.FirewallRules = networkMap.FirewallRules[:0]
networkMap.FirewallRulesIsEmpty = true networkMap.FirewallRulesIsEmpty = true
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 { if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
return return
} }
networkMap.FirewallRulesIsEmpty = false networkMap.FirewallRulesIsEmpty = false
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 1 { if len(acl.peerRulesPairs) != 1 {
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs)) t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
return return
@@ -346,7 +349,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
@@ -356,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
acl.ApplyFiltering(networkMap) acl.ApplyFiltering(networkMap, false)
if len(acl.peerRulesPairs) != 3 { if len(acl.peerRulesPairs) != 3 {
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs)) t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))

View File

@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
if runtime.GOOS == "linux" && !isLinuxDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
}
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
if runtime.GOOS == "freebsd" {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config)
} }

View File

@@ -94,12 +94,22 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
p.codeVerifier = codeVerifier p.codeVerifier = codeVerifier
codeChallenge := createCodeChallenge(codeVerifier) codeChallenge := createCodeChallenge(codeVerifier)
authURL := p.oAuthConfig.AuthCodeURL(
state, params := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_challenge_method", "S256"), oauth2.SetAuthURLParam("code_challenge_method", "S256"),
oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge", codeChallenge),
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience), oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
) }
if !p.providerConfig.DisablePromptLogin {
if p.providerConfig.LoginFlag.IsPromptLogin() {
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
}
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
return AuthFlowInfo{ return AuthFlowInfo{
VerificationURIComplete: authURL, VerificationURIComplete: authURL,

View File

@@ -0,0 +1,71 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal"
mgm "github.com/netbirdio/netbird/management/client/common"
)
func TestPromptLogin(t *testing.T) {
const (
promptLogin = "prompt=login"
maxAge0 = "max_age=0"
)
tt := []struct {
name string
loginFlag mgm.LoginFlag
disablePromptLogin bool
expect string
}{
{
name: "Prompt login",
loginFlag: mgm.LoginFlagPrompt,
expect: promptLogin,
},
{
name: "Max age 0 login",
loginFlag: mgm.LoginFlagMaxAge0,
expect: maxAge0,
},
{
name: "Disable prompt login",
loginFlag: mgm.LoginFlagPrompt,
disablePromptLogin: true,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
config := internal.PKCEAuthProviderConfig{
ClientID: "test-client-id",
Audience: "test-audience",
TokenEndpoint: "https://test-token-endpoint.com/token",
Scope: "openid email profile",
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
RedirectURLs: []string{"http://127.0.0.1:33992/"},
UseIDToken: true,
LoginFlag: tc.loginFlag,
}
pkce, err := NewPKCEAuthorizationFlow(config)
if err != nil {
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
}
authInfo, err := pkce.RequestAuthInfo(context.Background())
if err != nil {
t.Fatalf("Failed to request auth info: %v", err)
}
if !tc.disablePromptLogin {
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
} else {
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
}
})
}
}

View File

@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
return e return e
} }
// GetLatestNetworkMap returns the latest network map from the engine.
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
engine := c.Engine()
if engine == nil {
return nil, errors.New("engine is not initialized")
}
networkMap, err := engine.GetLatestNetworkMap()
if err != nil {
return nil, fmt.Errorf("get latest network map: %w", err)
}
if networkMap == nil {
return nil, errors.New("network map is not available")
}
return networkMap, nil
}
// Status returns the current client status // Status returns the current client status
func (c *ConnectClient) Status() StatusType { func (c *ConnectClient) Status() StatusType {
if c == nil { if c == nil {

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,8 @@
//go:build linux && !android //go:build linux && !android
package server package debug
import ( import (
"archive/zip"
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
@@ -14,36 +13,31 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/proto"
) )
// addFirewallRules collects and adds firewall rules to the archive // addFirewallRules collects and adds firewall rules to the archive
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { func (g *BundleGenerator) addFirewallRules() error {
log.Info("Collecting firewall rules") log.Info("Collecting firewall rules")
// Collect and add iptables rules
iptablesRules, err := collectIPTablesRules() iptablesRules, err := collectIPTablesRules()
if err != nil { if err != nil {
log.Warnf("Failed to collect iptables rules: %v", err) log.Warnf("Failed to collect iptables rules: %v", err)
} else { } else {
if req.GetAnonymize() { if g.anonymize {
iptablesRules = anonymizer.AnonymizeString(iptablesRules) iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
} }
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil { if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
log.Warnf("Failed to add iptables rules to bundle: %v", err) log.Warnf("Failed to add iptables rules to bundle: %v", err)
} }
} }
// Collect and add nftables rules
nftablesRules, err := collectNFTablesRules() nftablesRules, err := collectNFTablesRules()
if err != nil { if err != nil {
log.Warnf("Failed to collect nftables rules: %v", err) log.Warnf("Failed to collect nftables rules: %v", err)
} else { } else {
if req.GetAnonymize() { if g.anonymize {
nftablesRules = anonymizer.AnonymizeString(nftablesRules) nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
} }
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil { if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
log.Warnf("Failed to add nftables rules to bundle: %v", err) log.Warnf("Failed to add nftables rules to bundle: %v", err)
} }
} }
@@ -65,16 +59,23 @@ func collectIPTablesRules() (string, error) {
builder.WriteString("\n") builder.WriteString("\n")
} }
// Then get verbose statistics for each table // Collect ipset information
ipsetOutput, err := collectIPSets()
if err != nil {
log.Warnf("Failed to collect ipset information: %v", err)
} else {
builder.WriteString("=== ipset list output ===\n")
builder.WriteString(ipsetOutput)
builder.WriteString("\n")
}
builder.WriteString("=== iptables -v -n -L output ===\n") builder.WriteString("=== iptables -v -n -L output ===\n")
// Get list of tables
tables := []string{"filter", "nat", "mangle", "raw", "security"} tables := []string{"filter", "nat", "mangle", "raw", "security"}
for _, table := range tables { for _, table := range tables {
builder.WriteString(fmt.Sprintf("*%s\n", table)) builder.WriteString(fmt.Sprintf("*%s\n", table))
// Get verbose statistics for the entire table
stats, err := getTableStatistics(table) stats, err := getTableStatistics(table)
if err != nil { if err != nil {
log.Warnf("Failed to get statistics for table %s: %v", table, err) log.Warnf("Failed to get statistics for table %s: %v", table, err)
@@ -87,6 +88,28 @@ func collectIPTablesRules() (string, error) {
return builder.String(), nil return builder.String(), nil
} }
// collectIPSets collects information about ipsets
func collectIPSets() (string, error) {
cmd := exec.Command("ipset", "list")
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if strings.Contains(err.Error(), "executable file not found") {
return "", fmt.Errorf("ipset command not found: %w", err)
}
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
}
ipsets := stdout.String()
if strings.TrimSpace(ipsets) == "" {
return "No ipsets found", nil
}
return ipsets, nil
}
// collectIPTablesSave uses iptables-save to get rule definitions // collectIPTablesSave uses iptables-save to get rule definitions
func collectIPTablesSave() (string, error) { func collectIPTablesSave() (string, error) {
cmd := exec.Command("iptables-save") cmd := exec.Command("iptables-save")
@@ -182,12 +205,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
continue continue
} }
// Format chains
for _, chain := range chains { for _, chain := range chains {
formatChain(conn, table, chain, &builder) formatChain(conn, table, chain, &builder)
} }
// Format sets
if sets, err := conn.GetSets(table); err != nil { if sets, err := conn.GetSets(table); err != nil {
log.Warnf("Failed to get sets for table %s: %v", table.Name, err) log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
} else if len(sets) > 0 { } else if len(sets) > 0 {

View File

@@ -0,0 +1,7 @@
//go:build ios || android
package debug
func (g *BundleGenerator) addRoutes() error {
return nil
}

View File

@@ -0,0 +1,8 @@
//go:build !linux || android
package debug
// collectFirewallRules returns nothing on non-linux systems
func (g *BundleGenerator) addFirewallRules() error {
return nil
}

View File

@@ -0,0 +1,25 @@
//go:build !ios && !android
package debug
import (
"fmt"
"strings"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func (g *BundleGenerator) addRoutes() error {
routes, err := systemops.GetRoutesFromTable()
if err != nil {
return fmt.Errorf("get routes: %w", err)
}
// TODO: get routes including nexthop
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
routesReader := strings.NewReader(routesContent)
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
return fmt.Errorf("add routes file to zip: %w", err)
}
return nil
}

View File

@@ -0,0 +1,543 @@
package debug
import (
"encoding/json"
"net"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func TestAnonymizeStateFile(t *testing.T) {
testState := map[string]json.RawMessage{
"null_state": json.RawMessage("null"),
"test_state": mustMarshal(map[string]any{
// Test simple fields
"public_ip": "203.0.113.1",
"private_ip": "192.168.1.1",
"protected_ip": "100.64.0.1",
"well_known_ip": "8.8.8.8",
"ipv6_addr": "2001:db8::1",
"private_ipv6": "fd00::1",
"domain": "test.example.com",
"uri": "stun:stun.example.com:3478",
"uri_with_ip": "turn:203.0.113.1:3478",
"netbird_domain": "device.netbird.cloud",
// Test CIDR ranges
"public_cidr": "203.0.113.0/24",
"private_cidr": "192.168.0.0/16",
"protected_cidr": "100.64.0.0/10",
"ipv6_cidr": "2001:db8::/32",
"private_ipv6_cidr": "fd00::/8",
// Test nested structures
"nested": map[string]any{
"ip": "203.0.113.2",
"domain": "nested.example.com",
"more_nest": map[string]any{
"ip": "203.0.113.3",
"domain": "deep.example.com",
},
},
// Test arrays
"string_array": []any{
"203.0.113.4",
"test1.example.com",
"test2.example.com",
},
"object_array": []any{
map[string]any{
"ip": "203.0.113.5",
"domain": "array1.example.com",
},
map[string]any{
"ip": "203.0.113.6",
"domain": "array2.example.com",
},
},
// Test multiple occurrences of same value
"duplicate_ip": "203.0.113.1", // Same as public_ip
"duplicate_domain": "test.example.com", // Same as domain
// Test URIs with various schemes
"stun_uri": "stun:stun.example.com:3478",
"turns_uri": "turns:turns.example.com:5349",
"http_uri": "http://web.example.com:80",
"https_uri": "https://secure.example.com:443",
// Test strings that might look like IPs but aren't
"not_ip": "300.300.300.300",
"partial_ip": "192.168",
"ip_like_string": "1234.5678",
// Test mixed content strings
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
// Test empty and special values
"empty_string": "",
"null_value": nil,
"numeric_value": 42,
"boolean_value": true,
}),
"route_state": mustMarshal(map[string]any{
"routes": []any{
map[string]any{
"network": "203.0.113.0/24",
"gateway": "203.0.113.1",
"domains": []any{
"route1.example.com",
"route2.example.com",
},
},
map[string]any{
"network": "2001:db8::/32",
"gateway": "2001:db8::1",
"domains": []any{
"route3.example.com",
"route4.example.com",
},
},
},
// Test map with IP/CIDR keys
"refCountMap": map[string]any{
"203.0.113.1/32": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"2001:db8::1/128": map[string]any{
"Count": 1,
"Out": map[string]any{
"IP": "fe80::1",
"Intf": map[string]any{
"Name": "eth0",
"Index": 1,
},
},
},
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
"Count": 1,
"Out": map[string]any{
"IP": "192.168.0.1",
},
},
},
}),
}
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Pre-seed the domains we need to verify in the test assertions
anonymizer.AnonymizeDomain("test.example.com")
anonymizer.AnonymizeDomain("nested.example.com")
anonymizer.AnonymizeDomain("deep.example.com")
anonymizer.AnonymizeDomain("array1.example.com")
err := anonymizeStateFile(&testState, anonymizer)
require.NoError(t, err)
// Helper function to unmarshal and get nested values
var state map[string]any
err = json.Unmarshal(testState["test_state"], &state)
require.NoError(t, err)
// Test null state remains unchanged
require.Equal(t, "null", string(testState["null_state"]))
// Basic assertions
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
assert.NotEqual(t, "test.example.com", state["domain"])
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
// CIDR ranges
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
// Nested structures
nested := state["nested"].(map[string]any)
assert.NotEqual(t, "203.0.113.2", nested["ip"])
assert.NotEqual(t, "nested.example.com", nested["domain"])
moreNest := nested["more_nest"].(map[string]any)
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
// Arrays
strArray := state["string_array"].([]any)
assert.NotEqual(t, "203.0.113.4", strArray[0])
assert.NotEqual(t, "test1.example.com", strArray[1])
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
objArray := state["object_array"].([]any)
firstObj := objArray[0].(map[string]any)
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
// Duplicate values should be anonymized consistently
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
assert.Equal(t, state["domain"], state["duplicate_domain"])
// URIs
assert.NotContains(t, state["stun_uri"], "stun.example.com")
assert.NotContains(t, state["turns_uri"], "turns.example.com")
assert.NotContains(t, state["http_uri"], "web.example.com")
assert.NotContains(t, state["https_uri"], "secure.example.com")
// Non-IP strings should remain unchanged
assert.Equal(t, "300.300.300.300", state["not_ip"])
assert.Equal(t, "192.168", state["partial_ip"])
assert.Equal(t, "1234.5678", state["ip_like_string"])
// Mixed content should have IPs and domains replaced
mixedContent := state["mixed_content"].(string)
assert.NotContains(t, mixedContent, "203.0.113.1")
assert.NotContains(t, mixedContent, "test.example.com")
assert.Contains(t, mixedContent, "Server at ")
assert.Contains(t, mixedContent, " on port 80")
// Special values should remain unchanged
assert.Equal(t, "", state["empty_string"])
assert.Nil(t, state["null_value"])
assert.Equal(t, float64(42), state["numeric_value"])
assert.Equal(t, true, state["boolean_value"])
// Check route state
var routeState map[string]any
err = json.Unmarshal(testState["route_state"], &routeState)
require.NoError(t, err)
routes := routeState["routes"].([]any)
route1 := routes[0].(map[string]any)
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
assert.Contains(t, route1["network"], "/24")
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
domains := route1["domains"].([]any)
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
// Check map keys are anonymized
refCountMap := routeState["refCountMap"].(map[string]any)
hasPublicIPKey := false
hasIPv6Key := false
hasPrivateIPKey := false
for key := range refCountMap {
if strings.Contains(key, "203.0.113.1") {
hasPublicIPKey = true
}
if strings.Contains(key, "2001:db8::1") {
hasIPv6Key = true
}
if key == "10.0.0.1/32" {
hasPrivateIPKey = true
}
}
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
}
func mustMarshal(v any) json.RawMessage {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return data
}
func TestAnonymizeNetworkMap(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{
Address: "203.0.113.5",
Dns: "1.2.3.4",
Fqdn: "peer1.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
},
},
RemotePeers: []*mgmProto.RemotePeerConfig{
{
AllowedIps: []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
},
Fqdn: "peer2.corp.example.com",
SshConfig: &mgmProto.SSHConfig{
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
},
},
},
Routes: []*mgmProto.Route{
{
Network: "197.51.100.0/24",
Domains: []string{"prod.example.com", "staging.example.com"},
NetID: "net-123abc",
},
},
DNSConfig: &mgmProto.DNSConfig{
NameServerGroups: []*mgmProto.NameServerGroup{
{
NameServers: []*mgmProto.NameServer{
{IP: "8.8.8.8"},
{IP: "1.1.1.1"},
{IP: "203.0.113.53"},
},
Domains: []string{"example.com", "internal.example.com"},
},
},
CustomZones: []*mgmProto.CustomZone{
{
Domain: "custom.example.com",
Records: []*mgmProto.SimpleRecord{
{
Name: "www.custom.example.com",
Type: 1,
RData: "203.0.113.10",
},
{
Name: "internal.custom.example.com",
Type: 1,
RData: "192.168.1.10",
},
},
},
},
},
}
// Create anonymizer with test addresses
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Anonymize the network map
err := anonymizeNetworkMap(networkMap, anonymizer)
require.NoError(t, err)
// Test PeerConfig anonymization
peerCfg := networkMap.PeerConfig
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
// Verify DNS and FQDN are properly anonymized
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
// Verify SSH key is replaced
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
// Test RemotePeers anonymization
remotePeer := networkMap.RemotePeers[0]
// Verify FQDN is anonymized
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
// Check that public IPs are anonymized but private IPs are preserved
for _, allowedIP := range remotePeer.AllowedIps {
ip, _, err := net.ParseCIDR(allowedIP)
require.NoError(t, err)
if ip.IsPrivate() || isInCGNATRange(ip) {
require.Contains(t, []string{
"192.168.1.1/32",
"100.64.0.1/32",
"10.0.0.1/32",
}, allowedIP)
} else {
require.NotContains(t, []string{
"203.0.113.1/32",
"2001:db8:1234::1/128",
}, allowedIP)
}
}
// Test Routes anonymization
route := networkMap.Routes[0]
require.NotEqual(t, "197.51.100.0/24", route.Network)
for _, domain := range route.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test DNS config anonymization
dnsConfig := networkMap.DNSConfig
nameServerGroup := dnsConfig.NameServerGroups[0]
// Verify well-known DNS servers are preserved
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
// Verify public DNS server is anonymized
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
// Verify domains are anonymized
for _, domain := range nameServerGroup.Domains {
require.True(t, strings.HasSuffix(domain, ".domain"))
require.NotContains(t, domain, "example.com")
}
// Test CustomZones anonymization
customZone := dnsConfig.CustomZones[0]
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
require.NotContains(t, customZone.Domain, "example.com")
// Verify records are properly anonymized
for _, record := range customZone.Records {
require.True(t, strings.HasSuffix(record.Name, ".domain"))
require.NotContains(t, record.Name, "example.com")
ip := net.ParseIP(record.RData)
if ip != nil {
if !ip.IsPrivate() {
require.NotEqual(t, "203.0.113.10", record.RData)
} else {
require.Equal(t, "192.168.1.10", record.RData)
}
}
}
}
// Helper function to check if IP is in CGNAT range
func isInCGNATRange(ip net.IP) bool {
cgnat := net.IPNet{
IP: net.ParseIP("100.64.0.0"),
Mask: net.CIDRMask(10, 32),
}
return cgnat.Contains(ip)
}
func TestAnonymizeFirewallRules(t *testing.T) {
// TODO: Add ipv6
// Example iptables-save output
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
*filter
:INPUT ACCEPT [0:0]
:FORWARD ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
-A INPUT -s 192.168.1.0/24 -j ACCEPT
-A INPUT -s 44.192.140.1/32 -j DROP
-A FORWARD -s 10.0.0.0/8 -j DROP
-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT
COMMIT
*nat
:PREROUTING ACCEPT [0:0]
:INPUT ACCEPT [0:0]
:OUTPUT ACCEPT [0:0]
:POSTROUTING ACCEPT [0:0]
-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE
-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80
COMMIT`
// Example iptables -v -n -L output
iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination
0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0
100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0
Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination
0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0
25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
pkts bytes target prot opt in out source destination`
// Example nftables output
nftablesRules := `table inet filter {
chain input {
type filter hook input priority filter; policy accept;
ip saddr 192.168.1.1 accept
ip saddr 44.192.140.1 drop
}
chain forward {
type filter hook forward priority filter; policy accept;
ip saddr 10.0.0.0/8 drop
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
}
}`
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
// Test iptables-save anonymization
anonIptablesSave := anonymizer.AnonymizeString(iptablesSave)
// Private IP addresses should remain unchanged
assert.Contains(t, anonIptablesSave, "192.168.1.0/24")
assert.Contains(t, anonIptablesSave, "10.0.0.0/8")
assert.Contains(t, anonIptablesSave, "192.168.100.0/24")
assert.Contains(t, anonIptablesSave, "192.168.1.10")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonIptablesSave, "44.192.140.1")
assert.NotContains(t, anonIptablesSave, "44.192.140.0/24")
assert.NotContains(t, anonIptablesSave, "52.84.12.34")
assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range
// Structure should be preserved
assert.Contains(t, anonIptablesSave, "*filter")
assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]")
assert.Contains(t, anonIptablesSave, "COMMIT")
assert.Contains(t, anonIptablesSave, "-j MASQUERADE")
assert.Contains(t, anonIptablesSave, "--dport 80")
// Test iptables verbose output anonymization
anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose)
// Private IP addresses should remain unchanged
assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24")
assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonIptablesVerbose, "44.192.140.1")
assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24")
assert.NotContains(t, anonIptablesVerbose, "52.84.12.34")
assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range
// Structure and counters should be preserved
assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)")
assert.Contains(t, anonIptablesVerbose, "100 1024 DROP")
assert.Contains(t, anonIptablesVerbose, "pkts bytes target")
// Test nftables anonymization
anonNftables := anonymizer.AnonymizeString(nftablesRules)
// Private IP addresses should remain unchanged
assert.Contains(t, anonNftables, "192.168.1.1")
assert.Contains(t, anonNftables, "10.0.0.0/8")
// Public IP addresses should be anonymized to the default range
assert.NotContains(t, anonNftables, "44.192.140.1")
assert.NotContains(t, anonNftables, "44.192.140.0/24")
assert.NotContains(t, anonNftables, "52.84.12.34")
assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range
// Structure should be preserved
assert.Contains(t, anonNftables, "table inet filter {")
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}

View File

@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
continue continue
} }
listOfDomains = append(listOfDomains, dConf.Domain) listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
} }
return listOfDomains return listOfDomains
} }

View File

@@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
} }
// First remove any existing handler with same pattern (case-insensitive) and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { c.removeEntry(origPattern, priority)
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
break
}
}
// Check if handler implements SubdomainMatcher interface // Check if handler implements SubdomainMatcher interface
matchSubdomains := false matchSubdomains := false
@@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern) pattern = dns.Fqdn(pattern)
c.removeEntry(pattern, priority)
}
func (c *HandlerChain) removeEntry(pattern string, priority int) {
// Find and remove handlers matching both original pattern (case-insensitive) and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
return break
} }
} }
} }
// HasHandlers returns true if there are any handlers remaining for the given pattern
func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
return false
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return

View File

@@ -1,7 +1,6 @@
package dns_test package dns_test
import ( import (
"net"
"testing" "testing"
"github.com/miekg/dns" "github.com/miekg/dns"
@@ -9,6 +8,7 @@ import (
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
) )
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order // TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
@@ -30,7 +30,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
r.SetQuestion("example.com.", dns.TypeA) r.SetQuestion("example.com.", dns.TypeA)
// Create test writer // Create test writer
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations - only highest priority handler should be called // Setup expectations - only highest priority handler should be called
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
@@ -142,7 +142,7 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
@@ -259,7 +259,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
// Create and execute request // Create and execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.queryDomain, dns.TypeA) r.SetQuestion(tt.queryDomain, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
// Verify expectations // Verify expectations
@@ -316,7 +316,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
}).Once() }).Once()
// Execute // Execute
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w, r) chain.ServeDNS(w, r)
// Verify all handlers were called in order // Verify all handlers were called in order
@@ -325,20 +325,6 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
handler3.AssertExpectations(t) handler3.AssertExpectations(t)
} }
// mockResponseWriter implements dns.ResponseWriter for testing
type mockResponseWriter struct {
mock.Mock
}
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (m *mockResponseWriter) Close() error { return nil }
func (m *mockResponseWriter) TsigStatus() error { return nil }
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
func (m *mockResponseWriter) Hijack() {}
func TestHandlerChain_PriorityDeregistration(t *testing.T) { func TestHandlerChain_PriorityDeregistration(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -425,7 +411,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
// Create test request // Create test request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup expectations // Setup expectations
for priority, handler := range handlers { for priority, handler := range handlers {
@@ -443,14 +429,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
for _, handler := range handlers { for _, handler := range handlers {
handler.AssertExpectations(t) handler.AssertExpectations(t)
} }
// Verify handler exists check
for priority, shouldExist := range tt.expectedCalls {
if shouldExist {
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
"Handler chain should have handlers for pattern after removing priority %d", priority)
}
}
}) })
} }
} }
@@ -470,45 +448,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(testQuery, dns.TypeA) r.SetQuestion(testQuery, dns.TypeA)
// Keep track of mocks for the final assertion in Step 4
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
// Test 1: Initial state with all three handlers // Test 1: Initial state
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Highest priority handler (routeHandler) should be called // Highest priority handler (routeHandler) should be called
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
chain.ServeDNS(w, r) chain.ServeDNS(w1, r)
routeHandler.AssertExpectations(t) routeHandler.AssertExpectations(t)
routeHandler.ExpectedCalls = nil
routeHandler.Calls = nil
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 2: Remove highest priority handler // Test 2: Remove highest priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now middle priority handler (matchHandler) should be called // Now middle priority handler (matchHandler) should be called
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
chain.ServeDNS(w, r) chain.ServeDNS(w2, r)
matchHandler.AssertExpectations(t) matchHandler.AssertExpectations(t)
matchHandler.ExpectedCalls = nil
matchHandler.Calls = nil
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
assert.True(t, chain.HasHandlers(testDomain))
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
chain.ServeDNS(w, r) chain.ServeDNS(w3, r)
defaultHandler.AssertExpectations(t) defaultHandler.AssertExpectations(t)
defaultHandler.ExpectedCalls = nil
defaultHandler.Calls = nil
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain)) w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
for _, m := range mocks {
m.AssertNumberOfCalls(t, "ServeDNS", 0)
}
} }
func TestHandlerChain_CaseSensitivity(t *testing.T) { func TestHandlerChain_CaseSensitivity(t *testing.T) {
@@ -659,7 +661,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
// Execute request // Execute request
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r) chain.ServeDNS(&test.MockResponseWriter{}, r)
// Verify each handler was called exactly as expected // Verify each handler was called exactly as expected
for _, h := range tt.addHandlers { for _, h := range tt.addHandlers {
@@ -803,7 +805,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
r := new(dns.Msg) r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA) r.SetQuestion(tt.query, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Setup handler expectations // Setup handler expectations
for pattern, handler := range handlers { for pattern, handler := range handlers {
@@ -830,3 +832,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
}) })
} }
} }
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
tests := []struct {
name string
addPattern string
removePattern string
queryPattern string
shouldBeRemoved bool
description string
}{
{
name: "exact same pattern",
addPattern: "example.com.",
removePattern: "example.com.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical patterns",
},
{
name: "case difference",
addPattern: "Example.Com.",
removePattern: "EXAMPLE.COM.",
queryPattern: "example.com.",
shouldBeRemoved: true,
description: "Adding with mixed case, removing with uppercase",
},
{
name: "reversed case difference",
addPattern: "EXAMPLE.ORG.",
removePattern: "example.org.",
queryPattern: "example.org.",
shouldBeRemoved: true,
description: "Adding with uppercase, removing with lowercase",
},
{
name: "add wildcard, remove wildcard",
addPattern: "*.example.com.",
removePattern: "*.example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: true,
description: "Adding and removing with identical wildcard patterns",
},
{
name: "add wildcard, remove transformed pattern",
addPattern: "*.example.net.",
removePattern: "example.net.",
queryPattern: "sub.example.net.",
shouldBeRemoved: false,
description: "Adding with wildcard, removing with non-wildcard pattern",
},
{
name: "add transformed pattern, remove wildcard",
addPattern: "example.io.",
removePattern: "*.example.io.",
queryPattern: "example.io.",
shouldBeRemoved: false,
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
},
{
name: "trailing dot difference",
addPattern: "example.dev",
removePattern: "example.dev.",
queryPattern: "example.dev.",
shouldBeRemoved: true,
description: "Adding without trailing dot, removing with trailing dot",
},
{
name: "reversed trailing dot difference",
addPattern: "example.app.",
removePattern: "example.app",
queryPattern: "example.app.",
shouldBeRemoved: true,
description: "Adding with trailing dot, removing without trailing dot",
},
{
name: "mixed case and wildcard",
addPattern: "*.Example.Site.",
removePattern: "*.EXAMPLE.SITE.",
queryPattern: "sub.example.site.",
shouldBeRemoved: true,
description: "Adding mixed case wildcard, removing uppercase wildcard",
},
{
name: "root zone",
addPattern: ".",
removePattern: ".",
queryPattern: "random.domain.",
shouldBeRemoved: true,
description: "Adding and removing root zone",
},
{
name: "wrong domain",
addPattern: "example.com.",
removePattern: "different.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding one domain, trying to remove a different domain",
},
{
name: "subdomain mismatch",
addPattern: "sub.example.com.",
removePattern: "example.com.",
queryPattern: "sub.example.com.",
shouldBeRemoved: false,
description: "Adding subdomain, trying to remove parent domain",
},
{
name: "parent domain mismatch",
addPattern: "example.com.",
removePattern: "sub.example.com.",
queryPattern: "example.com.",
shouldBeRemoved: false,
description: "Adding parent domain, trying to remove subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handler := &nbdns.MockHandler{}
r := new(dns.Msg)
r.SetQuestion(tt.queryPattern, dns.TypeA)
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// First verify no handler is called before adding any
chain.ServeDNS(w, r)
handler.AssertNotCalled(t, "ServeDNS")
// Add handler
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
// Verify handler is called after adding
handler.On("ServeDNS", mock.Anything, r).Once()
chain.ServeDNS(w, r)
handler.AssertExpectations(t)
// Reset mock for the next test
handler.ExpectedCalls = nil
// Remove handler
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
// Set up expectations based on whether removal should succeed
if !tt.shouldBeRemoved {
handler.On("ServeDNS", mock.Anything, r).Once()
}
// Test if handler is still called after removal attempt
chain.ServeDNS(w, r)
if tt.shouldBeRemoved {
handler.AssertNotCalled(t, "ServeDNS",
"Handler should not be called after successful removal with pattern %q",
tt.removePattern)
} else {
handler.AssertExpectations(t)
handler.ExpectedCalls = nil
}
})
}
}

View File

@@ -5,6 +5,8 @@ import (
"net/netip" "net/netip"
"strings" "strings"
"github.com/miekg/dns"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
) )
@@ -12,8 +14,8 @@ import (
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured") var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
const ( const (
ipv4ReverseZone = ".in-addr.arpa" ipv4ReverseZone = ".in-addr.arpa."
ipv6ReverseZone = ".ip6.arpa" ipv6ReverseZone = ".ip6.arpa."
) )
type hostManager interface { type hostManager interface {
@@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
for _, domain := range nsConfig.Domains { for _, domain := range nsConfig.Domains {
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(domain, "."), Domain: strings.ToLower(dns.Fqdn(domain)),
MatchOnly: !nsConfig.SearchDomainsEnabled, MatchOnly: !nsConfig.SearchDomainsEnabled,
}) })
} }
@@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
for _, customZone := range dnsConfig.CustomZones { for _, customZone := range dnsConfig.CustomZones {
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone) matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
config.Domains = append(config.Domains, DomainConfig{ config.Domains = append(config.Domains, DomainConfig{
Domain: strings.TrimSuffix(customZone.Domain, "."), Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
MatchOnly: matchOnly, MatchOnly: matchOnly,
}) })
} }

View File

@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, dConf.Domain) matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
continue continue
} }
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
} }
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)

View File

@@ -17,15 +17,18 @@ import (
var ( var (
userenv = syscall.NewLazyDLL("userenv.dll") userenv = syscall.NewLazyDLL("userenv.dll")
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex // https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx") refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
) )
const ( const (
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match` dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient` gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\DnsPolicyConfig`
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match` gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\NetBird-Match`
dnsPolicyConfigVersionKey = "Version" dnsPolicyConfigVersionKey = "Version"
dnsPolicyConfigVersionValue = 2 dnsPolicyConfigVersionValue = 2
@@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
continue continue
} }
if !dConf.MatchOnly { if !dConf.MatchOnly {
searchDomains = append(searchDomains, dConf.Domain) searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
} }
matchDomains = append(matchDomains, "."+dConf.Domain) matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
} }
if len(matchDomains) != 0 { if len(matchDomains) != 0 {
@@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
} }
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }
@@ -136,10 +143,6 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
return fmt.Errorf("configure GPO DNS policy: %w", err) return fmt.Errorf("configure GPO DNS policy: %w", err)
} }
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
return fmt.Errorf("configure local DNS policy: %w", err)
}
if err := refreshGroupPolicy(); err != nil { if err := refreshGroupPolicy(); err != nil {
log.Warnf("failed to refresh group policy: %v", err) log.Warnf("failed to refresh group policy: %v", err)
} }
@@ -188,6 +191,26 @@ func (r *registryConfigurator) string() string {
return "registry" return "registry"
} }
func (r *registryConfigurator) flushDNSCache() error {
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() {
if rec := recover(); rec != nil {
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
}
}()
ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
}
return fmt.Errorf("DnsFlushResolverCache failed")
}
log.Info("flushed DNS cache")
return nil
}
func (r *registryConfigurator) updateSearchDomains(domains []string) error { func (r *registryConfigurator) updateSearchDomains(domains []string) error {
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
return fmt.Errorf("update search domains: %w", err) return fmt.Errorf("update search domains: %w", err)
@@ -240,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
return fmt.Errorf("remove interface registry key: %w", err) return fmt.Errorf("remove interface registry key: %w", err)
} }
if err := r.flushDNSCache(); err != nil {
log.Errorf("failed to flush DNS cache: %v", err)
}
return nil return nil
} }

View File

@@ -1,124 +0,0 @@
package dns
import (
"fmt"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
)
type registrationMap map[string]struct{}
type localResolver struct {
registeredMap registrationMap
records sync.Map // key: string (domain_class_type), value: []dns.RR
}
func (d *localResolver) MatchSubdomains() bool {
return true
}
func (d *localResolver) stop() {
}
// String returns a string representation of the local resolver
func (d *localResolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap))
}
// ID returns the unique handler ID
func (d *localResolver) id() handlerID {
return "local-resolver"
}
// ServeDNS handles a DNS request
func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) > 0 {
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
}
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(r)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
err := w.WriteMsg(replyMessage)
if err != nil {
log.Debugf("got an error while writing the local resolver response, error: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
if len(r.Question) == 0 {
return nil
}
question := r.Question[0]
question.Name = strings.ToLower(question.Name)
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)
value, found := d.records.Load(key)
if !found {
return nil
}
records, ok := value.([]dns.RR)
if !ok {
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
return nil
}
// if there's more than one record, rotate them (round-robin)
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records.Store(key, records)
}
return records
}
// registerRecord stores a new record by appending it to any existing list
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
rr, err := dns.NewRR(record.String())
if err != nil {
return "", fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
key := buildRecordKey(header.Name, header.Class, header.Rrtype)
// load any existing slice of records, then append
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
records := existing.([]dns.RR)
records = append(records, rr)
// store updated slice
d.records.Store(key, records)
return key, nil
}
// deleteRecord removes *all* records under the recordKey.
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}
// buildRecordKey consistently generates a key: name_class_type
func buildRecordKey(name string, class, qType uint16) string {
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
}
func (d *localResolver) probeAvailability() {}

View File

@@ -0,0 +1,149 @@
package local
import (
"fmt"
"slices"
"strings"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/dns/types"
nbdns "github.com/netbirdio/netbird/dns"
)
type Resolver struct {
mu sync.RWMutex
records map[dns.Question][]dns.RR
}
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question][]dns.RR),
}
}
func (d *Resolver) MatchSubdomains() bool {
return true
}
// String returns a string representation of the local resolver
func (d *Resolver) String() string {
return fmt.Sprintf("local resolver [%d records]", len(d.records))
}
func (d *Resolver) Stop() {}
// ID returns the unique handler ID
func (d *Resolver) ID() types.HandlerID {
return "local-resolver"
}
func (d *Resolver) ProbeAvailability() {}
// ServeDNS handles a DNS request
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
log.Debugf("received local resolver request with no question")
return
}
question := r.Question[0]
question.Name = strings.ToLower(dns.Fqdn(question.Name))
log.Tracef("received local question: domain=%s type=%v class=%v", r.Question[0].Name, question.Qtype, question.Qclass)
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
// lookup all records matching the question
records := d.lookupRecords(question)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
// TODO: return success if we have a different record type for the same name, relevant for search domains
replyMessage.Rcode = dns.RcodeNameError
}
if err := w.WriteMsg(replyMessage); err != nil {
log.Warnf("failed to write the local resolver response: %v", err)
}
}
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *Resolver) lookupRecords(question dns.Question) []dns.RR {
d.mu.RLock()
records, found := d.records[question]
if !found {
d.mu.RUnlock()
// alternatively check if we have a cname
if question.Qtype != dns.TypeCNAME {
question.Qtype = dns.TypeCNAME
return d.lookupRecords(question)
}
return nil
}
recordsCopy := slices.Clone(records)
d.mu.RUnlock()
// if there's more than one record, rotate them (round-robin)
if len(recordsCopy) > 1 {
d.mu.Lock()
records = d.records[question]
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records[question] = records
}
d.mu.Unlock()
}
return recordsCopy
}
func (d *Resolver) Update(update []nbdns.SimpleRecord) {
d.mu.Lock()
defer d.mu.Unlock()
maps.Clear(d.records)
for _, rec := range update {
if err := d.registerRecord(rec); err != nil {
log.Warnf("failed to register the record (%s): %v", rec, err)
continue
}
}
}
// RegisterRecord stores a new record by appending it to any existing list
func (d *Resolver) RegisterRecord(record nbdns.SimpleRecord) error {
d.mu.Lock()
defer d.mu.Unlock()
return d.registerRecord(record)
}
// registerRecord performs the registration with the lock already held
func (d *Resolver) registerRecord(record nbdns.SimpleRecord) error {
rr, err := dns.NewRR(record.String())
if err != nil {
return fmt.Errorf("register record: %w", err)
}
rr.Header().Rdlength = record.Len()
header := rr.Header()
q := dns.Question{
Name: strings.ToLower(dns.Fqdn(header.Name)),
Qtype: header.Rrtype,
Qclass: header.Class,
}
d.records[q] = append(d.records[q], rr)
return nil
}

View File

@@ -0,0 +1,472 @@
package local
import (
"strings"
"testing"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := NewResolver()
_ = resolver.RegisterRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}
// TestLocalResolver_Update_StaleRecord verifies that updating
// a record correctly replaces the old one, preventing stale entries.
func TestLocalResolver_Update_StaleRecord(t *testing.T) {
recordName := "host.example.com."
recordType := dns.TypeA
recordClass := dns.ClassINET
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "1.1.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "2.2.2.2",
}
recordKey := dns.Question{Name: recordName, Qtype: uint16(recordClass), Qclass: recordType}
resolver := NewResolver()
update1 := []nbdns.SimpleRecord{record1}
update2 := []nbdns.SimpleRecord{record2}
// Apply first update
resolver.Update(update1)
// Verify first update
resolver.mu.RLock()
rrSlice1, found1 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found1, "Record key %s not found after first update", recordKey)
require.Len(t, rrSlice1, 1, "Should have exactly 1 record after first update")
assert.Contains(t, rrSlice1[0].String(), record1.RData, "Record after first update should be %s", record1.RData)
// Apply second update
resolver.Update(update2)
// Verify second update
resolver.mu.RLock()
rrSlice2, found2 := resolver.records[recordKey]
resolver.mu.RUnlock()
require.True(t, found2, "Record key %s not found after second update", recordKey)
require.Len(t, rrSlice2, 1, "Should have exactly 1 record after update overwriting the key")
assert.Contains(t, rrSlice2[0].String(), record2.RData, "The single record should be the updated one (%s)", record2.RData)
assert.NotContains(t, rrSlice2[0].String(), record1.RData, "The stale record (%s) should not be present", record1.RData)
}
// TestLocalResolver_MultipleRecords_SameQuestion verifies that multiple records
// with the same question are stored properly
func TestLocalResolver_MultipleRecords_SameQuestion(t *testing.T) {
resolver := NewResolver()
recordName := "multi.example.com."
recordType := dns.TypeA
// Create two records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2",
}
update := []nbdns.SimpleRecord{record1, record2}
// Apply update with both records
resolver.Update(update)
// Create question that matches both records
question := dns.Question{
Name: recordName,
Qtype: recordType,
Qclass: dns.ClassINET,
}
// Verify both records are stored
resolver.mu.RLock()
records, found := resolver.records[question]
resolver.mu.RUnlock()
require.True(t, found, "Records for question %v not found", question)
require.Len(t, records, 2, "Should have exactly 2 records for the same question")
// Verify both record data values are present
recordStrings := []string{records[0].String(), records[1].String()}
assert.Contains(t, recordStrings[0]+recordStrings[1], record1.RData, "First record data should be present")
assert.Contains(t, recordStrings[0]+recordStrings[1], record2.RData, "Second record data should be present")
}
// TestLocalResolver_RecordRotation verifies that records are rotated in a round-robin fashion
func TestLocalResolver_RecordRotation(t *testing.T) {
resolver := NewResolver()
recordName := "rotation.example.com."
recordType := dns.TypeA
// Create three records with the same name and type but different IPs
record1 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.1",
}
record2 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.2",
}
record3 := nbdns.SimpleRecord{
Name: recordName, Type: int(recordType), Class: nbdns.DefaultClass, TTL: 300, RData: "192.168.1.3",
}
update := []nbdns.SimpleRecord{record1, record2, record3}
// Apply update with all three records
resolver.Update(update)
msg := new(dns.Msg).SetQuestion(recordName, recordType)
// First lookup - should return the records in original order
var responses [3]*dns.Msg
// Perform three lookups to verify rotation
for i := 0; i < 3; i++ {
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responses[i] = m
return nil
},
}
resolver.ServeDNS(responseWriter, msg)
}
// Verify all three responses contain answers
for i, resp := range responses {
require.NotNil(t, resp, "Response %d should not be nil", i)
require.Len(t, resp.Answer, 3, "Response %d should have 3 answers", i)
}
// Verify the first record in each response is different due to rotation
firstRecordIPs := []string{
responses[0].Answer[0].String(),
responses[1].Answer[0].String(),
responses[2].Answer[0].String(),
}
// Each record should be different (rotated)
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[1], "First lookup should differ from second lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[1], firstRecordIPs[2], "Second lookup should differ from third lookup due to rotation")
assert.NotEqual(t, firstRecordIPs[0], firstRecordIPs[2], "First lookup should differ from third lookup due to rotation")
// After three rotations, we should have cycled through all records
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record1.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record2.RData)
assert.Contains(t, firstRecordIPs[0]+firstRecordIPs[1]+firstRecordIPs[2], record3.RData)
}
// TestLocalResolver_CaseInsensitiveMatching verifies that DNS record lookups are case-insensitive
func TestLocalResolver_CaseInsensitiveMatching(t *testing.T) {
resolver := NewResolver()
// Create record with lowercase name
lowerCaseRecord := nbdns.SimpleRecord{
Name: "lower.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.10.10.10",
}
// Create record with mixed case name
mixedCaseRecord := nbdns.SimpleRecord{
Name: "MiXeD.ExAmPlE.CoM.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "20.20.20.20",
}
// Update resolver with the records
resolver.Update([]nbdns.SimpleRecord{lowerCaseRecord, mixedCaseRecord})
testCases := []struct {
name string
queryName string
expectedRData string
shouldResolve bool
}{
{
name: "Query lowercase with lowercase record",
queryName: "lower.example.com.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query uppercase with lowercase record",
queryName: "LOWER.EXAMPLE.COM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query mixed case with lowercase record",
queryName: "LoWeR.eXaMpLe.CoM.",
expectedRData: "10.10.10.10",
shouldResolve: true,
},
{
name: "Query lowercase with mixed case record",
queryName: "mixed.example.com.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query uppercase with mixed case record",
queryName: "MIXED.EXAMPLE.COM.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query with different casing pattern",
queryName: "mIxEd.ExaMpLe.cOm.",
expectedRData: "20.20.20.20",
shouldResolve: true,
},
{
name: "Query non-existent domain",
queryName: "nonexistent.example.com.",
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case name
msg := new(dns.Msg).SetQuestion(tc.queryName, dns.TypeA)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 {
// Expected no answer, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected IP address %s, got: %s",
tc.expectedRData, answerString)
})
}
}
// TestLocalResolver_CNAMEFallback verifies that the resolver correctly falls back
// to checking for CNAME records when the requested record type isn't found
func TestLocalResolver_CNAMEFallback(t *testing.T) {
resolver := NewResolver()
// Create a CNAME record (but no A record for this name)
cnameRecord := nbdns.SimpleRecord{
Name: "alias.example.com.",
Type: int(dns.TypeCNAME),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "target.example.com.",
}
// Create an A record for the CNAME target
targetRecord := nbdns.SimpleRecord{
Name: "target.example.com.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.100.100",
}
// Update resolver with both records
resolver.Update([]nbdns.SimpleRecord{cnameRecord, targetRecord})
testCases := []struct {
name string
queryName string
queryType uint16
expectedType string
expectedRData string
shouldResolve bool
}{
{
name: "Directly query CNAME record",
queryName: "alias.example.com.",
queryType: dns.TypeCNAME,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query A record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query AAAA record but get CNAME fallback",
queryName: "alias.example.com.",
queryType: dns.TypeAAAA,
expectedType: "CNAME",
expectedRData: "target.example.com.",
shouldResolve: true,
},
{
name: "Query direct A record",
queryName: "target.example.com.",
queryType: dns.TypeA,
expectedType: "A",
expectedRData: "192.168.100.100",
shouldResolve: true,
},
{
name: "Query non-existent name",
queryName: "nonexistent.example.com.",
queryType: dns.TypeA,
shouldResolve: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var responseMSG *dns.Msg
// Create DNS query with the test case parameters
msg := new(dns.Msg).SetQuestion(tc.queryName, tc.queryType)
// Create mock response writer to capture the response
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
// Perform DNS query
resolver.ServeDNS(responseWriter, msg)
// Check if we expect a successful resolution
if !tc.shouldResolve {
if responseMSG == nil || len(responseMSG.Answer) == 0 || responseMSG.Rcode != dns.RcodeSuccess {
// Expected no resolution, test passes
return
}
t.Fatalf("Expected no resolution for %s, but got answer: %v", tc.queryName, responseMSG.Answer)
}
// Verify we got a successful response
require.NotNil(t, responseMSG, "Should have received a response message")
require.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "Response should have success status code")
require.Greater(t, len(responseMSG.Answer), 0, "Response should contain at least one answer")
// Verify the response contains the expected data
answerString := responseMSG.Answer[0].String()
assert.Contains(t, answerString, tc.expectedType,
"Answer should be of type %s, got: %s", tc.expectedType, answerString)
assert.Contains(t, answerString, tc.expectedRData,
"Answer should contain the expected data %s, got: %s", tc.expectedRData, answerString)
})
}
}

View File

@@ -1,88 +0,0 @@
package dns
import (
"strings"
"testing"
"github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestLocalResolver_ServeDNS(t *testing.T) {
recordA := nbdns.SimpleRecord{
Name: "peera.netbird.cloud.",
Type: 1,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "1.2.3.4",
}
recordCNAME := nbdns.SimpleRecord{
Name: "peerb.netbird.cloud.",
Type: 5,
Class: nbdns.DefaultClass,
TTL: 300,
RData: "www.netbird.io",
}
testCases := []struct {
name string
inputRecord nbdns.SimpleRecord
inputMSG *dns.Msg
responseShouldBeNil bool
}{
{
name: "Should Resolve A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion(recordA.Name, dns.TypeA),
},
{
name: "Should Resolve CNAME Record",
inputRecord: recordCNAME,
inputMSG: new(dns.Msg).SetQuestion(recordCNAME.Name, dns.TypeCNAME),
},
{
name: "Should Not Write When Not Found A Record",
inputRecord: recordA,
inputMSG: new(dns.Msg).SetQuestion("not.found.com", dns.TypeA),
responseShouldBeNil: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_, _ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
resolver.ServeDNS(responseWriter, testCase.inputMSG)
if responseMSG == nil || len(responseMSG.Answer) == 0 {
if testCase.responseShouldBeNil {
return
}
t.Fatalf("should write a response message")
}
answerString := responseMSG.Answer[0].String()
if !strings.Contains(answerString, testCase.inputRecord.Name) {
t.Fatalf("answer doesn't contain the same domain name: \nWant: %s\nGot:%s", testCase.name, answerString)
}
if !strings.Contains(answerString, dns.Type(testCase.inputRecord.Type).String()) {
t.Fatalf("answer doesn't contain the correct type: \nWant: %s\nGot:%s", dns.Type(testCase.inputRecord.Type).String(), answerString)
}
if !strings.Contains(answerString, testCase.inputRecord.RData) {
t.Fatalf("answer doesn't contain the same address: \nWant: %s\nGot:%s", testCase.inputRecord.RData, answerString)
}
})
}
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
) )
// MockServer is the mock instance of a dns server // MockServer is the mock instance of a dns server
@@ -13,17 +14,17 @@ type MockServer struct {
InitializeFunc func() error InitializeFunc func() error
StopFunc func() StopFunc func()
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
RegisterHandlerFunc func([]string, dns.Handler, int) RegisterHandlerFunc func(domain.List, dns.Handler, int)
DeregisterHandlerFunc func([]string, int) DeregisterHandlerFunc func(domain.List, int)
} }
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
if m.RegisterHandlerFunc != nil { if m.RegisterHandlerFunc != nil {
m.RegisterHandlerFunc(domains, handler, priority) m.RegisterHandlerFunc(domains, handler, priority)
} }
} }
func (m *MockServer) DeregisterHandler(domains []string, priority int) { func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
if m.DeregisterHandlerFunc != nil { if m.DeregisterHandlerFunc != nil {
m.DeregisterHandlerFunc(domains, priority) m.DeregisterHandlerFunc(domains, priority)
} }

View File

@@ -1,26 +0,0 @@
package dns
import (
"net"
"github.com/miekg/dns"
)
type mockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error
}
func (rw *mockResponseWriter) WriteMsg(m *dns.Msg) error {
if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m)
}
return nil
}
func (rw *mockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *mockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
func (rw *mockResponseWriter) Close() error { return nil }
func (rw *mockResponseWriter) TsigStatus() error { return nil }
func (rw *mockResponseWriter) TsigTimersOnly(bool) {}
func (rw *mockResponseWriter) Hijack() {}

View File

@@ -13,7 +13,6 @@ import (
"github.com/godbus/dbus/v5" "github.com/godbus/dbus/v5"
"github.com/hashicorp/go-version" "github.com/hashicorp/go-version"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
continue continue
} }
if dConf.MatchOnly { if dConf.MatchOnly {
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain)) matchDomains = append(matchDomains, "~."+dConf.Domain)
continue continue
} }
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain)) searchDomains = append(searchDomains, dConf.Domain)
} }
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic

Some files were not shown because too many files have changed in this diff Show More