Compare commits

...

116 Commits

Author SHA1 Message Date
Viktor Liu
3efa7a282a Set log level debug 2024-11-27 13:46:37 +01:00
Viktor Liu
40551099b3 Add debug 2024-11-27 13:41:32 +01:00
Pascal Fischer
9db1932664 [management] Fix getSetupKey call (#2927) 2024-11-22 10:15:51 +01:00
Viktor Liu
1bbabf70b0 [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict

* Fix chain name
2024-11-21 16:53:37 +01:00
Pascal Fischer
aa575d6f44 [management] Add activity events to group propagation flow (#2916) 2024-11-21 15:10:34 +01:00
Pascal Fischer
f66bbcc54c [management] Add metric for peer meta update (#2913) 2024-11-19 18:13:26 +01:00
Pascal Fischer
5dd6a08ea6 link peer meta update back to account object (#2911) 2024-11-19 17:25:49 +01:00
Krzysztof Nazarewski (kdn)
eb5d0569ae [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)
* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
2024-11-19 14:14:58 +01:00
Pascal Fischer
52ea2e84e9 [management] Add transaction metrics and exclude getAccount time from peers update (#2904) 2024-11-19 00:04:50 +01:00
Maycon Santos
78fab877c0 [misc] Update signing pipeline version (#2900) 2024-11-18 15:31:53 +01:00
Maycon Santos
65a94f695f use google domain for tests (#2902) 2024-11-18 12:55:02 +01:00
Kursat Aktas
ec543f89fb Introducing NetBird Guru on Gurubase.io (#2778) 2024-11-16 15:45:31 +01:00
Viktor Liu
a7d5c52203 Fix error state race on mgmt connection error (#2892) 2024-11-15 22:59:49 +01:00
Viktor Liu
582bb58714 Move state updates outside the refcounter (#2897) 2024-11-15 22:55:33 +01:00
Viktor Liu
121dfda915 [client] Fix state manager race conditions (#2890) 2024-11-15 20:05:26 +01:00
İsmail
a1c5287b7c Fix the Inactivity Expiration problem. (#2865) 2024-11-15 18:21:27 +01:00
Bethuel Mmbaga
12f442439a [management] Refactor group to use store methods (#2867)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor account peers update

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor groups to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor GetGroupByID and add NewGroupNotFoundError

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add AddPeer and RemovePeer methods to Group struct

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Preserve store engine in SqlStore transactions

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run groups ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Change setup key log level to debug for missing group

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve modified peers once for group events

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add account locking and merge group deletion methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-15 20:09:32 +03:00
Pascal Fischer
d9b691b8a5 [management] Limit the setup-key update operation (#2841) 2024-11-15 17:00:06 +01:00
Pascal Fischer
4aee3c9e33 [client/management] add peer lock to peer meta update and fix isEqual func (#2840) 2024-11-15 16:59:03 +01:00
Pascal Fischer
44e799c687 [management] Fix limited peer view groups (#2894) 2024-11-15 11:16:16 +01:00
Viktor Liu
be78efbd42 [client] Handle panic on nil wg interface (#2891) 2024-11-14 20:15:16 +01:00
Maycon Santos
6886691213 Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
2024-11-13 15:21:33 +01:00
Zoltan Papp
b48afd92fd [relay-server] Always close ws conn when work thread exit (#2879)
Close ws conn when work thread exit
2024-11-13 15:02:51 +01:00
Viktor Liu
39329e12a1 [client] Improve state write timeout and abort work early on timeout (#2882)
* Improve state write timeout and abort work early on timeout

* Don't block on initial persist state
2024-11-13 13:46:00 +01:00
Pascal Fischer
20a5afc359 [management] Add more logs to the peer update processes (#2881) 2024-11-12 14:19:22 +01:00
Bethuel Mmbaga
6cb697eed6 [management] Refactor setup key to use store methods (#2861)
* Refactor setup key handling to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add lock to get account groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add check for regular user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get only required groups for auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add account lock and return auto groups map on validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix missing group removed from setup key activity

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Remove context from DB queries

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Add user permission check and add setup events into events to store slice

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Retrieve all groups once during setup key auto-group validation

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix sonar

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-11-11 19:46:10 +03:00
Viktor Liu
e0bed2b0fb [client] Fix race conditions (#2869)
* Fix concurrent map access in status

* Fix race when retrieving ctx state error

* Fix race when accessing service controller server instance
2024-11-11 14:55:10 +01:00
Zoltan Papp
30f025e7dd [client] fix/proxy close (#2873)
When the remote peer switches the Relay instance then must to close the proxy connection to the old instance.

It can cause issues when the remote peer switch connects to the Relay instance multiple times and then reconnects to an instance it had previously connected to.
2024-11-11 14:18:38 +01:00
Zoltan Papp
b4d7605147 [client] Remove loop after route calculation (#2856)
- ICE do not trigger disconnect callbacks if the stated did not change
- Fix route calculation callback loop
- Move route state updates into protected scope by mutex
- Do not calculate routes in case of peer.Open() and peer.Close()
2024-11-11 10:53:57 +01:00
Viktor Liu
08b6e9d647 [management] Fix api error message typo peers_group (#2862) 2024-11-08 23:28:02 +01:00
Pascal Fischer
67ce14eaea [management] Add peer lock to grpc server (#2859)
* add peer lock to grpc server

* remove sleep and put db update first

* don't export lock method
2024-11-08 18:47:22 +01:00
Pascal Fischer
669904cd06 [management] Remove context from database calls (#2863) 2024-11-08 15:49:00 +01:00
Zoltan Papp
4be826450b [client] Use offload in WireGuard bind receiver (#2815)
Improve the performance on Linux and Android in case of P2P connections
2024-11-07 17:28:38 +01:00
Maycon Santos
738387f2de Add benchmark tests to get account with claims (#2761)
* Add benchmark tests to get account with claims

* add users to account objects

* remove hardcoded env
2024-11-07 17:23:35 +01:00
Pascal Fischer
baf0678ceb [management] Fix potential panic on inactivity expiration log message (#2854) 2024-11-07 16:33:57 +01:00
Pascal Fischer
7fef8f6758 [management] Enforce max conn of 1 for sqlite setups (#2855) 2024-11-07 16:32:35 +01:00
Viktor Liu
6829a64a2d [client] Exclude split default route ip addresses from anonymization (#2853) 2024-11-07 16:29:32 +01:00
Zoltan Papp
cbf500024f [relay-server] Use X-Real-IP in case of reverse proxy (#2848)
* Use X-Real-IP in case of reverse proxy

* Use sprintf
2024-11-07 16:14:53 +01:00
Viktor Liu
509e184e10 [client] Use the prerouting chain to mark for masquerading to support older systems (#2808) 2024-11-07 12:37:04 +01:00
Pascal Fischer
3e88b7c56e [management] Fix network map update on peer validation (#2849) 2024-11-07 09:50:13 +01:00
Maycon Santos
b952d8693d Fix cached device flow oauth (#2833)
This change removes the cached device flow oauth info when a down command is called

Removing the need for the agent to be restarted
2024-11-05 14:51:17 +01:00
Maycon Santos
5b46cc8e9c Avoid failing all other matrix tests if one fails (#2839) 2024-11-05 13:28:42 +01:00
Pascal Fischer
a9d06b883f add all group to add peer affected peers network map check (#2830) 2024-11-01 22:09:08 +01:00
Viktor Liu
5f06b202c3 [client] Log windows panics (#2829) 2024-11-01 15:08:22 +01:00
Zoltan Papp
0eb99c266a Fix unused servers cleanup (#2826)
The cleanup loop did not manage those situations well when a connection failed or 
the connection success but the code did not add a peer connection to it yet.

- in the cleanup loop check if a connection failed to a server
- after adding a foreign server connection force to keep it a minimum 5 sec
2024-11-01 12:33:29 +01:00
Pascal Fischer
bac95ace18 [management] Add DB access duration to logs for context cancel (#2781) 2024-11-01 10:58:39 +01:00
Zoltan Papp
9812de853b Allocate new buffer for every package (#2823) 2024-11-01 00:33:25 +01:00
Zoltan Papp
ad4f0a6fdf [client] Nil check on ICE remote conn (#2806) 2024-10-31 23:18:35 +01:00
Pascal Fischer
4c758c6e52 [management] remove network map diff calculations (#2820) 2024-10-31 19:24:15 +01:00
Misha Bragin
ec5095ba6b Create FUNDING.yml (#2814) 2024-10-30 17:25:02 +01:00
Misha Bragin
49a54624f8 Create funding.json (#2813) 2024-10-30 17:18:27 +01:00
Pascal Fischer
729bcf2b01 [management] add metrics to network map diff (#2811) 2024-10-30 16:53:23 +01:00
Jing
a0cdb58303 [client] Fix the broken dependency gvisor.dev/gvisor (#2789)
The release was removed which is described at
https://github.com/google/gvisor/issues/11085#issuecomment-2438974962.
2024-10-29 20:17:40 +01:00
pascal-fischer
39c99781cb fix meta is equal slices (#2807) 2024-10-29 19:54:38 +01:00
Marco Garcês
01f24907c5 [client] Fix multiple peer name filtering in netbird status command (#2798) 2024-10-29 17:49:41 +01:00
pascal-fischer
10480eb52f [management] Setup key improvements (#2775) 2024-10-28 17:52:23 +01:00
pascal-fischer
1e44c5b574 [client] allow relay leader on iOS (#2795) 2024-10-28 16:55:00 +01:00
Viktor Liu
940f8b4547 [client] Remove legacy forwarding rules in userspace mode (#2782) 2024-10-28 12:29:29 +01:00
Viktor Liu
46e37fa04c [client] Ignore route rules with no sources instead of erroring out (#2786) 2024-10-28 12:28:44 +01:00
Stefano
b9f205b2ce [misc] Update Zitadel from v2.54.10 to v2.64.1 2024-10-28 10:08:17 +01:00
Viktor Liu
0fd874fa45 [client] Make native firewall init fail firewall creation (#2784) 2024-10-28 10:02:27 +01:00
Viktor Liu
8016710d24 [client] Cleanup firewall state on startup (#2768) 2024-10-24 14:46:24 +02:00
Zoltan Papp
4e918e55ba [client] Fix controller re-connection (#2758)
Rethink the peer reconnection implementation
2024-10-24 11:43:14 +02:00
Viktor Liu
869537c951 [client] Cleanup dns and route states on startup (#2757) 2024-10-24 10:53:46 +02:00
Zoltan Papp
44f2ce666e [relay-client] Log exposed address (#2771)
* Log exposed address
2024-10-23 18:32:27 +02:00
pascal-fischer
563dca705c [management] Fix session inactivity response (#2770) 2024-10-23 16:40:15 +02:00
Bethuel Mmbaga
7bda385e1b [management] Optimize network map updates (#2718)
* Skip peer update on unchanged network map (#2236)

* Enhance network updates by skipping unchanged messages

Optimizes the network update process
by skipping updates where no changes in the peer update message received.

* Add unit tests

* add locks

* Improve concurrency and update peer message handling

* Refactor account manager network update tests

* fix test

* Fix inverted network map update condition

* Add default group and policy to test data

* Run peer updates in a separate goroutine

* Refactor

* Refactor lock

* Fix peers update by including NetworkMap and posture Checks

* go mod tidy

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* [management] Skip account peers update if no changes affect peers (#2310)

* Remove incrementing network serial and updating peers after group deletion

* Update account peer if posture check is linked to policy

* Remove account peers update on saving setup key

* Refactor group link checking into re-usable functions

* Add HasPeers function to group

* Refactor group management

* Optimize group change effects on account peers

* Update account peers if ns group has peers

* Refactor group changes

* Optimize account peers update in DNS settings

* Optimize update of account peers on jwt groups sync

* Refactor peer account updates for efficiency

* Optimize peer update on user deletion and changes

* Remove condition check for network serial update

* Optimize account peers updates on route changes

* Remove UpdatePeerSSHKey method

* Remove unused isPolicyRuleGroupsEmpty

* Add tests for peer update behavior on posture check changes

* Add tests for peer update behavior on policy changes

* Add tests for peer update behavior on group changes

* Add tests for peer update behavior on dns settings changes

* Refactor

* Add tests for peer update behavior on name server changes

* Add tests for peer update behavior on user changes

* Add tests for peer update behavior on route changes

* fix tests

* Add tests for peer update behavior on setup key changes

* Add tests for peer update behavior on peers changes

* fix merge

* Fix tests

* go mod tidy

* Add NameServer and Route comparators

* Update network map diff logic with custom comparators

* Add tests

* Refactor duplicate diff handling logic

* fix linter

* fix tests

* Refactor policy group handling and update logic.

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update route check by checking if group has peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor posture check policy linking logic

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Simplify peer update condition in DNS management

Refactor the condition for updating account peers to remove redundant checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add policy tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add posture checks tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix user and setup key tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix account and route tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix typo

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix nameserver tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix routes tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix group tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* upgrade diff package

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix nameserver tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* use generic differ for netip.Addr and netip.Prefix

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* go mod tidy

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add peer tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix management suite tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix postgres tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* enable diff nil structs comparison

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip the update only last sent the serial is larger

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor peer and user

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip spell check for groupD

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor group, ns group, policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* skip spell check for GroupD

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update account policy check before verifying policy status

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* Update management/server/route_test.go

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

* add tests missing tests for dns setting groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add tests for posture checks changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add ns group and policy tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add route and group tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* increase Linux test timeout to 10 minutes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run diff for client posture checks only

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add panic recovery and detailed logging in peer update comparison

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-10-23 13:05:02 +03:00
Zoltan Papp
30ebcf38c7 [client] Eliminate UDP proxy in user-space mode (#2712)
In the case of user space WireGuard mode, use in-memory proxy between the TURN/Relay connection and the WireGuard Bind. We keep the UDP proxy and eBPF proxy for kernel mode.

The key change is the new wgproxy/bind and the iface/bind/ice_bind changes. Everything else is just to fulfill the dependencies.
2024-10-22 20:53:14 +02:00
Bethuel Mmbaga
0106a95f7a lock account and use transaction (#2767)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-22 13:29:17 +03:00
Maycon Santos
9929b22afc Replace suite tests with regular go tests (#2762)
* Replace file suite tests with go tests

* Replace file suite tests with go tests
2024-10-21 14:39:28 +02:00
Maycon Santos
88e4fc2245 Release global lock on early error (#2760) 2024-10-19 18:32:17 +02:00
Maycon Santos
c8d8748dcf Update sign workflow version (#2756) 2024-10-18 17:28:58 +02:00
Maycon Santos
507a40bd7f Fix decompress zip path (#2755)
Since 0.30.2 the decompressed binary path from the signed package has changed

now it doesn't contain the arch suffix

this change handles that
2024-10-17 20:39:59 +02:00
Maycon Santos
ccd4ae6315 Fix domain information is up to date check (#2754) 2024-10-17 19:21:35 +02:00
Bethuel Mmbaga
96d2207684 Fix JSON function compatibility for SQLite and PostgreSQL (#2746)
resolves the issue with json_array_length compatibility between SQLite and PostgreSQL. It adjusts the query to conditionally cast types:

PostgreSQL: Casts to json with ::json.
SQLite: Uses the text representation directly.
2024-10-16 17:55:30 +02:00
Emre Oksum
f942491b91 Update Zitadel version on quickstart script (#2744)
Update Zitadel version at docker compose in quickstart script from 2.54.3 to 2.54.10 because 2.54.3 isn't stable and has a lot of bugs.
2024-10-16 17:51:21 +02:00
Viktor Liu
8c8900be57 [client] Exclude loopback from NAT (#2747) 2024-10-16 17:35:59 +02:00
Maycon Santos
cee95461d1 [client] Add universal bin build and update sign workflow version (#2738)
* Add universal binaries build for macOS

* update sign pipeline version

* handle info.plist in sign workflow
2024-10-15 15:03:17 +02:00
ctrl-zzz
49e65109d2 Add session expire functionality based on inactivity (#2326)
Implemented inactivity expiration by checking the status of a peer: after a configurable period of time following netbird down, the peer shows login required.
2024-10-13 14:52:43 +02:00
Zoltan Papp
d93dd4fc7f [relay-server] Move the handshake logic to separated struct (#2648)
* Move the handshake logic to separated struct

- The server will response to the client after it ready to process the peer
- Preload the response messages

* Fix deprecated lint issue

* Fix error handling

* [relay-server] Relay measure auth time (#2675)

Measure the Relay client's authentication time
2024-10-12 18:21:34 +02:00
Viktor Liu
3a88ac78ff [client] Add table filter rules using iptables (#2727)
This specifically concerns the established/related rule since this one is not compatible with iptables-nft even if it is generated the same way by iptables-translate.
2024-10-12 10:44:48 +02:00
Maycon Santos
da3a053e2b [management] Refactor getAccountIDWithAuthorizationClaims (#2715)
This change restructures the getAccountIDWithAuthorizationClaims method to improve readability, maintainability, and performance.

- have dedicated methods to handle possible cases
- introduced Store.UpdateAccountDomainAttributes and Store.GetAccountUsers methods
- Remove GetAccount and SaveAccount dependency
- added tests
2024-10-12 08:35:51 +02:00
Zoltan Papp
0e95f16cdd [relay,client] Relay/fix/wg roaming (#2691)
If a peer connection switches from Relayed to ICE P2P, the Relayed proxy still consumes the data the other peer sends. Because the proxy is operating, the WireGuard switches back to the Relayed proxy automatically, thanks to the roaming feature.

Extend the Proxy implementation with pause/resume functions. Before switching to the p2p connection, pause the WireGuard proxy operation to prevent unnecessary package sources.
Consider waiting some milliseconds after the pause to be sure the WireGuard engine already processed all UDP msg in from the pipe.
2024-10-11 16:24:30 +02:00
pascal-fischer
b2379175fe [signal] new signal dispatcher version (#2722) 2024-10-10 16:23:46 +02:00
Viktor Liu
09bdd271f1 [client] Improve route acl (#2705)
- Update nftables library to v0.2.0
- Mark traffic that was originally destined for local and applies the input rules in the forward chain if said traffic was redirected (e.g. by Docker)
- Add nft rules to internal map only if flush was successful
- Improve error message if handle is 0 (= not found or hasn't been refreshed)
- Add debug logging when route rules are added
- Replace nftables userdata (rule ID) with a rule hash
2024-10-10 15:54:34 +02:00
Misha Bragin
208a2b7169 Add billing user role (#2714) 2024-10-10 14:14:56 +02:00
pascal-fischer
8284ae959c [management] Move testdata to sql files (#2693) 2024-10-10 12:35:03 +02:00
Maycon Santos
6ce09bca16 Add support to envsub go management configurations (#2708)
This change allows users to reference environment variables using Go template format, like {{ .EnvName }}

Moved the previous file test code to file_suite_test.go.
2024-10-09 20:46:23 +02:00
pascal-fischer
b79c1d64cc [management] Make max open db conns configurable (#2713) 2024-10-09 20:17:25 +02:00
Misha Bragin
b1eda43f4b Add Link to the Lawrence Systems video (#2711) 2024-10-09 14:56:25 +02:00
pascal-fischer
d4ef84fe6e [management] Propagate error in store errors (#2709) 2024-10-09 14:33:58 +02:00
Viktor Liu
44e8107383 [client] Limit P2P attempts and restart on specific events (#2657) 2024-10-08 11:21:11 +02:00
Bethuel Mmbaga
2c1f5e46d5 [management] Validate peer ownership during login (#2704)
* check peer ownership in login

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* update error message

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-07 19:06:26 +03:00
pascal-fischer
dbec24b520 [management] Remove admin check on getAccountByID (#2699) 2024-10-06 17:01:13 +02:00
Carlos Hernandez
f603cd9202 [client] Check wginterface instead of engine ctx (#2676)
Moving code to ensure wgInterface is gone right after context is
cancelled/stop in the off chance that on next retry the backoff
operation is permanently cancelled and interface is abandoned without
destroying.
2024-10-04 19:15:16 +02:00
Bethuel Mmbaga
5897a48e29 fix wrong reference (#2695)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:55:25 +03:00
Bethuel Mmbaga
8bf729c7b4 [management] Add AccountExists to AccountManager (#2694)
* Add AccountExists method to account manager interface

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove unused code

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 18:09:40 +03:00
Bethuel Mmbaga
7f09b39769 [management] Refactor User JWT group sync (#2690)
* Refactor GetAccountIDByUserOrAccountID

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* sync user jwt group changes

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* propagate jwt group changes to peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix no jwt groups synced

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests and lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Move the account peer update outside the transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* move event store outside of transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* get user with update lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Run jwt sync in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-10-04 17:17:01 +03:00
pascal-fischer
158936fb15 [management] Remove file store (#2689) 2024-10-03 15:50:35 +02:00
Maycon Santos
8934453b30 Update management base docker image (#2687) 2024-10-02 19:29:51 +03:00
Zoltan Papp
fd67892cb4 [client] Refactor/iface pkg (#2646)
Refactor the flat code structure
2024-10-02 18:24:22 +02:00
pascal-fischer
7e5d3bdfe2 [signal] Move dummy signal message handling into dispatcher (#2686) 2024-10-02 15:33:38 +02:00
Maycon Santos
b7b0828133 [client] Adjust relay worker log level and message (#2683) 2024-10-02 15:14:09 +02:00
Bethuel Mmbaga
ff7863785f [management, client] Add access control support to network routes (#2100) 2024-10-02 13:41:00 +02:00
Maycon Santos
a3a479429e Use the pkgs to get the latest version (#2682)
* Use the pkgs to get the latest version

* disable fail fast
2024-10-02 11:48:42 +02:00
Maycon Santos
5932298ce0 Add log setting to Caddy container (#2684)
This avoids full disk on busy systems
2024-10-02 11:48:09 +02:00
Zoltan Papp
ee0ea86a0a [relay-client] Fix Relay disconnection handling (#2680)
* Fix Relay disconnection handling

If has an active P2P connection meanwhile the Relay connection broken with the server then we removed the WireGuard peer configuration.

* Change logs
2024-10-01 16:22:18 +02:00
Simen
24c0aaa745 Install sh alpine fixes (#2678)
* Made changes to the peer install script that makes it work on alpine linux without changes

* fix small oversight with doas fix

* use try catch approach when curling binaries
2024-10-01 13:32:58 +02:00
pascal-fischer
16179db599 [management] Propagate metrics (#2667) 2024-09-30 22:18:10 +02:00
Maycon Santos
e27f85b317 Update docker creds (#2677) 2024-09-30 20:07:21 +02:00
Gianluca Boiano
2fd60b2cb4 Specify goreleaser version and update to 2 (#2673) 2024-09-30 16:43:34 +02:00
Zoltan Papp
3dca6099d4 Fix ebpf close function (#2672) 2024-09-30 10:34:57 +02:00
pascal-fischer
cfbcf507fb propagate meter (#2668) 2024-09-29 20:23:34 +02:00
pascal-fischer
52ae693c9e [signal] add context to signal-dispatcher (#2662) 2024-09-29 00:22:47 +02:00
adasauce
58ff7ab797 [management] improve zitadel idp error response detail by decoding errors (#2634)
* [management] improve zitadel idp error response detail by decoding errors

* [management] extend readZitadelError to be used for requestJWTToken

more generically parse the error returned by zitadel.

* fix lint

---------

Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-27 22:21:34 +03:00
Bethuel Mmbaga
acb73bd64a [management] Remove redundant get account calls in GetAccountFromToken (#2615)
* refactor access control middleware and user access by JWT groups

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor jwt groups extractor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to get account when necessary

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* revert handles change

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove GetUserByID from account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor getAccountWithAuthorizationClaims to return account id

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor handlers to use GetAccountIDFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* remove locks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByName from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add GetGroupByID from store and refactor

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor user permissions and retrieves PAT

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor route, setupkey, nameserver and dns to get record(s) from store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Refactor store

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix lint

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix add missing policy source posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add store lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix tests

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get account

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2024-09-27 17:10:50 +03:00
350 changed files with 19175 additions and 9769 deletions

3
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,3 @@
# These are supported funding model platforms
github: [netbirdio]

View File

@@ -42,4 +42,4 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...

View File

@@ -38,7 +38,7 @@ jobs:
time go test -timeout 1m -failfast ./dns/... time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/... time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/... time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./iface/... time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/... time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/... time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./signal/... time go test -timeout 1m -failfast ./signal/...

View File

@@ -13,10 +13,11 @@ concurrency:
jobs: jobs:
test: test:
strategy: strategy:
fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-latest runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
@@ -49,7 +50,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Test - name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
test_client_on_docker: test_client_on_docker:
runs-on: ubuntu-20.04 runs-on: ubuntu-20.04
@@ -79,9 +80,6 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Generate Iface Test bin
run: CGO_ENABLED=0 go test -c -o iface-testing.bin ./iface/
- name: Generate Shared Sock Test bin - name: Generate Shared Sock Test bin
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
@@ -98,7 +96,7 @@ jobs:
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
- name: Generate Peer Test bin - name: Generate Peer Test bin
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/... run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
- run: chmod +x *testing.bin - run: chmod +x *testing.bin
@@ -106,7 +104,7 @@ jobs:
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
- name: Run Iface tests in docker - name: Run Iface tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/iface --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/iface-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
- name: Run RouteManager tests in docker - name: Run RouteManager tests in docker
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1 run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable, ignore_words_list: erro,clienta,hastable,iif,groupd
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@@ -13,6 +13,7 @@ concurrency:
jobs: jobs:
test-install-script: test-install-script:
strategy: strategy:
fail-fast: false
max-parallel: 2 max-parallel: 2
matrix: matrix:
os: [ubuntu-latest, macos-latest] os: [ubuntu-latest, macos-latest]

View File

@@ -3,15 +3,14 @@ name: Release
on: on:
push: push:
tags: tags:
- 'v*' - "v*"
branches: branches:
- main - main
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.14" SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v1.14.1" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
@@ -21,7 +20,7 @@ concurrency:
jobs: jobs:
release: release:
runs-on: ubuntu-latest runs-on: ubuntu-22.04
env: env:
flags: "" flags: ""
steps: steps:
@@ -34,19 +33,16 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- - name: Checkout
name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- - name: Set up Go
name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: "1.23" go-version: "1.23"
cache: false cache: false
- - name: Cache Go modules
name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: | path: |
@@ -55,24 +51,19 @@ jobs:
key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-go-releaser-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-go-releaser- ${{ runner.os }}-go-releaser-
- - name: Install modules
name: Install modules
run: go mod tidy run: go mod tidy
- - name: check git status
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- - name: Set up QEMU
name: Set up QEMU
uses: docker/setup-qemu-action@v2 uses: docker/setup-qemu-action@v2
- - name: Set up Docker Buildx
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2 uses: docker/setup-buildx-action@v2
- - name: Login to Docker hub
name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@v1 uses: docker/login-action@v1
with: with:
username: netbirdio username: ${{ secrets.DOCKER_USER }}
password: ${{ secrets.DOCKER_TOKEN }} password: ${{ secrets.DOCKER_TOKEN }}
- name: Install OS build dependencies - name: Install OS build dependencies
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
@@ -85,35 +76,31 @@ jobs:
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --rm-dist ${{ env.flags }} args: release --clean ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }} UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
- - name: upload non tags for debug purposes
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 3 retention-days: 3
- - name: upload linux packages
name: upload linux packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 3 retention-days: 3
- - name: upload windows packages
name: upload windows packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 3 retention-days: 3
- - name: upload macos packages
name: upload macos packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
@@ -145,7 +132,7 @@ jobs:
- name: Cache Go modules - name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: | path: |
~/go/pkg/mod ~/go/pkg/mod
~/.cache/go-build ~/.cache/go-build
key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-releaser-${{ hashFiles('**/go.sum') }}
@@ -169,7 +156,7 @@ jobs:
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --rm-dist ${{ env.flags }} args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }} HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
@@ -187,19 +174,16 @@ jobs:
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- - name: Checkout
name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
- - name: Set up Go
name: Set up Go
uses: actions/setup-go@v5 uses: actions/setup-go@v5
with: with:
go-version: "1.23" go-version: "1.23"
cache: false cache: false
- - name: Cache Go modules
name: Cache Go modules
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
path: | path: |
@@ -208,23 +192,19 @@ jobs:
key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }} key: ${{ runner.os }}-ui-go-releaser-darwin-${{ hashFiles('**/go.sum') }}
restore-keys: | restore-keys: |
${{ runner.os }}-ui-go-releaser-darwin- ${{ runner.os }}-ui-go-releaser-darwin-
- - name: Install modules
name: Install modules
run: go mod tidy run: go mod tidy
- - name: check git status
name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- - name: Run GoReleaser
name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --rm-dist ${{ env.flags }} args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- - name: upload non tags for debug purposes
name: upload non tags for debug purposes
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui-darwin name: release-ui-darwin
@@ -233,7 +213,7 @@ jobs:
trigger_signer: trigger_signer:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release,release_ui,release_ui_darwin] needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger binaries sign pipelines
@@ -243,4 +223,4 @@ jobs:
repo: netbirdio/sign-pipelines repo: netbirdio/sign-pipelines
ref: ${{ env.SIGN_PIPE_VER }} ref: ${{ env.SIGN_PIPE_VER }}
token: ${{ secrets.SIGN_GITHUB_TOKEN }} token: ${{ secrets.SIGN_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref }}" }' inputs: '{ "tag": "${{ github.ref }}", "skipRelease": false }'

View File

@@ -1,3 +1,5 @@
version: 2
project_name: netbird project_name: netbird
builds: builds:
- id: netbird - id: netbird
@@ -22,7 +24,7 @@ builds:
goarch: 386 goarch: 386
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
@@ -42,19 +44,19 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
- id: netbird-mgmt - id: netbird-mgmt
dir: management dir: management
env: env:
- CGO_ENABLED=1 - CGO_ENABLED=1
- >- - >-
{{- if eq .Runtime.Goos "linux" }} {{- if eq .Runtime.Goos "linux" }}
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }} {{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }} {{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
{{- end }} {{- end }}
binary: netbird-mgmt binary: netbird-mgmt
goos: goos:
- linux - linux
@@ -64,7 +66,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-signal - id: netbird-signal
dir: signal dir: signal
@@ -78,7 +80,7 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-relay - id: netbird-relay
dir: relay dir: relay
@@ -92,7 +94,10 @@ builds:
- arm - arm
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
universal_binaries:
- id: netbird
archives: archives:
- builds: - builds:
@@ -100,7 +105,6 @@ archives:
- netbird-static - netbird-static
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client. description: Netbird client.
homepage: https://netbird.io/ homepage: https://netbird.io/
@@ -416,10 +420,9 @@ docker_manifests:
- netbirdio/management:{{ .Version }}-debug-amd64 - netbirdio/management:{{ .Version }}-debug-amd64
brews: brews:
- - ids:
ids:
- default - default
tap: repository:
owner: netbirdio owner: netbirdio
name: homebrew-tap name: homebrew-tap
token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}" token: "{{ .Env.HOMEBREW_TAP_GITHUB_TOKEN }}"
@@ -436,7 +439,7 @@ brews:
uploads: uploads:
- name: debian - name: debian
ids: ids:
- netbird-deb - netbird-deb
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com username: dev@wiretrustee.com

View File

@@ -1,3 +1,5 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui - id: netbird-ui
@@ -11,7 +13,7 @@ builds:
- amd64 - amd64
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
- id: netbird-ui-windows - id: netbird-ui-windows
dir: client/ui dir: client/ui
@@ -26,7 +28,7 @@ builds:
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
- -H windowsgui - -H windowsgui
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
archives: archives:
- id: linux-arch - id: linux-arch
@@ -39,7 +41,6 @@ archives:
- netbird-ui-windows - netbird-ui-windows
nfpms: nfpms:
- maintainer: Netbird <dev@netbird.io> - maintainer: Netbird <dev@netbird.io>
description: Netbird client UI. description: Netbird client UI.
homepage: https://netbird.io/ homepage: https://netbird.io/
@@ -77,7 +78,7 @@ nfpms:
uploads: uploads:
- name: debian - name: debian
ids: ids:
- netbird-ui-deb - netbird-ui-deb
mode: archive mode: archive
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package= target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
username: dev@wiretrustee.com username: dev@wiretrustee.com

View File

@@ -1,3 +1,5 @@
version: 2
project_name: netbird-ui project_name: netbird-ui
builds: builds:
- id: netbird-ui-darwin - id: netbird-ui-darwin
@@ -17,10 +19,13 @@ builds:
- softfloat - softfloat
ldflags: ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser - -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: '{{ .CommitTimestamp }}' mod_timestamp: "{{ .CommitTimestamp }}"
tags: tags:
- load_wgnt_from_rsrc - load_wgnt_from_rsrc
universal_binaries:
- id: netbird-ui-darwin
archives: archives:
- builds: - builds:
- netbird-ui-darwin - netbird-ui-darwin
@@ -28,4 +33,4 @@ archives:
checksum: checksum:
name_template: "{{ .ProjectName }}_darwin_checksums.txt" name_template: "{{ .ProjectName }}_darwin_checksums.txt"
changelog: changelog:
skip: true disable: true

View File

@@ -96,7 +96,7 @@ They can be executed from the repository root before every push or PR:
**Goreleaser** **Goreleaser**
```shell ```shell
goreleaser --snapshot --rm-dist goreleaser build --snapshot --clean
``` ```
**golangci-lint** **golangci-lint**
```shell ```shell

View File

@@ -19,6 +19,10 @@
<br> <br>
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ"> <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/> <img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
</a>
<br>
<a href="https://gurubase.io/g/netbird">
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
</a> </a>
</p> </p>
</div> </div>
@@ -49,6 +53,8 @@
![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab) ![netbird_2](https://github.com/netbirdio/netbird/assets/700848/46bc3b73-508d-4a0e-bb9a-f465d68646ab)
### NetBird on Lawrence Systems (Video)
[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw)
### Key features ### Key features
@@ -62,6 +68,7 @@
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> | | | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> | | | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
| | | | | <ul><li> - \[x] Docker </ul></li> | | | | | | <ul><li> - \[x] Docker </ul></li> |
### Quickstart with NetBird Cloud ### Quickstart with NetBird Cloud
- Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install) - Download and install NetBird at [https://app.netbird.io/install](https://app.netbird.io/install)

View File

@@ -8,6 +8,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
@@ -15,7 +16,6 @@ import (
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util/net" "github.com/netbirdio/netbird/util/net"
) )
@@ -26,7 +26,7 @@ type ConnectionListener interface {
// TunAdapter export internal TunAdapter for mobile // TunAdapter export internal TunAdapter for mobile
type TunAdapter interface { type TunAdapter interface {
iface.TunAdapter device.TunAdapter
} }
// IFaceDiscover export internal IFaceDiscover for mobile // IFaceDiscover export internal IFaceDiscover for mobile
@@ -51,7 +51,7 @@ func init() {
// Client struct manage the life circle of background service // Client struct manage the life circle of background service
type Client struct { type Client struct {
cfgFile string cfgFile string
tunAdapter iface.TunAdapter tunAdapter device.TunAdapter
iFaceDiscover IFaceDiscover iFaceDiscover IFaceDiscover
recorder *peer.Status recorder *peer.Status
ctxCancel context.CancelFunc ctxCancel context.CancelFunc

View File

@@ -201,6 +201,8 @@ func isWellKnown(addr netip.Addr) bool {
"2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6 "2606:4700:4700::1111", "2606:4700:4700::1001", // Cloudflare DNS IPv6
"9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4 "9.9.9.9", "149.112.112.112", // Quad9 DNS IPv4
"2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6 "2620:fe::fe", "2620:fe::9", // Quad9 DNS IPv6
"128.0.0.0", "8000::", // 2nd split subnet for default routes
} }
if slices.Contains(wellKnown, addr.String()) { if slices.Contains(wellKnown, addr.String()) {

View File

@@ -5,8 +5,8 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/iface"
) )
func TestInitCommands(t *testing.T) { func TestInitCommands(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"sync"
"github.com/kardianos/service" "github.com/kardianos/service"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -13,10 +14,11 @@ import (
) )
type program struct { type program struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
serv *grpc.Server serv *grpc.Server
serverInstance *server.Server serverInstance *server.Server
serverInstanceMu sync.Mutex
} }
func newProgram(ctx context.Context, cancel context.CancelFunc) *program { func newProgram(ctx context.Context, cancel context.CancelFunc) *program {

View File

@@ -61,7 +61,9 @@ func (p *program) Start(svc service.Service) error {
} }
proto.RegisterDaemonServiceServer(p.serv, serverInstance) proto.RegisterDaemonServiceServer(p.serv, serverInstance)
p.serverInstanceMu.Lock()
p.serverInstance = serverInstance p.serverInstance = serverInstance
p.serverInstanceMu.Unlock()
log.Printf("started daemon server: %v", split[1]) log.Printf("started daemon server: %v", split[1])
if err := p.serv.Serve(listen); err != nil { if err := p.serv.Serve(listen); err != nil {
@@ -72,6 +74,7 @@ func (p *program) Start(svc service.Service) error {
} }
func (p *program) Stop(srv service.Service) error { func (p *program) Stop(srv service.Service) error {
p.serverInstanceMu.Lock()
if p.serverInstance != nil { if p.serverInstance != nil {
in := new(proto.DownRequest) in := new(proto.DownRequest)
_, err := p.serverInstance.Down(p.ctx, in) _, err := p.serverInstance.Down(p.ctx, in)
@@ -79,6 +82,7 @@ func (p *program) Stop(srv service.Service) error {
log.Errorf("failed to stop daemon: %v", err) log.Errorf("failed to stop daemon: %v", err)
} }
} }
p.serverInstanceMu.Unlock()
p.cancel() p.cancel()

View File

@@ -680,7 +680,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool { func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
statusEval := false statusEval := false
ipEval := false ipEval := false
nameEval := false nameEval := true
if statusFilter != "" { if statusFilter != "" {
lowerStatusFilter := strings.ToLower(statusFilter) lowerStatusFilter := strings.ToLower(statusFilter)
@@ -700,11 +700,13 @@ func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
if len(prefixNamesFilter) > 0 { if len(prefixNamesFilter) > 0 {
for prefixNameFilter := range prefixNamesFilterMap { for prefixNameFilter := range prefixNamesFilterMap {
if !strings.HasPrefix(peerState.Fqdn, prefixNameFilter) { if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
nameEval = true nameEval = false
break break
} }
} }
} else {
nameEval = false
} }
return statusEval || ipEval || nameEval return statusEval || ipEval || nameEval

View File

@@ -3,7 +3,6 @@ package cmd
import ( import (
"context" "context"
"net" "net"
"path/filepath"
"testing" "testing"
"time" "time"
@@ -34,18 +33,12 @@ func startTestingServices(t *testing.T) string {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testDir := t.TempDir()
config.Datadir = testDir
err = util.CopyFileContents("../testdata/store.json", filepath.Join(testDir, "store.json"))
if err != nil {
t.Fatal(err)
}
_, signalLis := startSignal(t) _, signalLis := startSignal(t)
signalAddr := signalLis.Addr().String() signalAddr := signalLis.Addr().String()
config.Signal.URI = signalAddr config.Signal.URI = signalAddr
_, mgmLis := startManagement(t, config) _, mgmLis := startManagement(t, config, "../testdata/store.sql")
mgmAddr := mgmLis.Addr().String() mgmAddr := mgmLis.Addr().String()
return mgmAddr return mgmAddr
} }
@@ -57,7 +50,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
srv, err := sig.NewServer(otel.Meter("")) srv, err := sig.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err) require.NoError(t, err)
sigProto.RegisterSignalExchangeServer(s, srv) sigProto.RegisterSignalExchangeServer(s, srv)
@@ -70,7 +63,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
return s, lis return s, lis
} }
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) { func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
t.Helper() t.Helper()
lis, err := net.Listen("tcp", ":0") lis, err := net.Listen("tcp", ":0")
@@ -78,7 +71,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
t.Fatal(err) t.Fatal(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
store, cleanUp, err := mgmt.NewTestStoreFromJson(context.Background(), config.Datadir) store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -15,11 +15,11 @@ import (
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )

View File

@@ -3,7 +3,6 @@
package firewall package firewall
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
@@ -11,10 +10,11 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }

View File

@@ -3,7 +3,7 @@
package firewall package firewall
import ( import (
"context" "errors"
"fmt" "fmt"
"os" "os"
@@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
@@ -32,54 +33,65 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(context context.Context, iface IFaceMapper) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
var fm firewall.Manager fm, err := createNativeFirewall(iface, stateManager)
var errFw error
if !iface.IsUserspaceBind() {
return fm, err
}
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
fm, err := createFW(iface)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
if err = fm.Init(stateManager); err != nil {
return nil, fmt.Errorf("init firewall: %s", err)
}
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Info("creating an iptables firewall manager")
fm, errFw = nbiptables.Create(context, iface) return nbiptables.Create(iface)
if errFw != nil {
log.Errorf("failed to create iptables manager: %s", errFw)
}
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Info("creating an nftables firewall manager")
fm, errFw = nbnftables.Create(context, iface) return nbnftables.Create(iface)
if errFw != nil {
log.Errorf("failed to create nftables manager: %s", errFw)
}
default: default:
errFw = fmt.Errorf("no firewall manager found")
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
} }
if iface.IsUserspaceBind() { if errUsp != nil {
var errUsp error return nil, fmt.Errorf("create userspace firewall: %s", errUsp)
if errFw == nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
} else {
fm, errUsp = uspfilter.Create(iface)
}
if errUsp != nil {
log.Debugf("failed to create userspace filtering firewall: %s", errUsp)
return nil, errUsp
}
if err := fm.AllowNetbird(); err != nil {
log.Errorf("failed to allow netbird interface traffic: %v", err)
}
return fm, nil
} }
if errFw != nil { if err := fm.AllowNetbird(); err != nil {
return nil, errFw log.Errorf("failed to allow netbird interface traffic: %v", err)
} }
return fm, nil return fm, nil
} }

View File

@@ -1,11 +1,13 @@
package firewall package firewall
import "github.com/netbirdio/netbird/iface" import (
"github.com/netbirdio/netbird/client/iface/device"
)
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() device.WGAddress
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(iface.PacketFilter) error SetFilter(device.PacketFilter) error
} }

View File

@@ -11,6 +11,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -19,49 +21,65 @@ const (
// rules chains contains the effective ACL rules // rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT" chainNameInputRules = "NETBIRD-ACL-INPUT"
chainNameOutputRules = "NETBIRD-ACL-OUTPUT" chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
postRoutingMark = "0x000007e4"
) )
type aclManager struct { type aclEntries map[string][][]string
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routeingFwChainName string
entries map[string][][]string type entry struct {
ipsetStore *ipsetStore spec []string
position int
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routeingFwChainName string) (*aclManager, error) { type aclManager struct {
iptablesClient *iptables.IPTables
wgIface iFaceMapper
routingFwChainName string
entries aclEntries
optionalEntries map[string][]entry
ipsetStore *ipsetStore
stateManager *statemanager.Manager
}
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routeingFwChainName: routeingFwChainName, routingFwChainName: routingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
ipsetStore: newIpsetStore(), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(),
} }
err := ipset.Init() if err := ipset.Init(); err != nil {
if err != nil { return nil, fmt.Errorf("init ipset: %w", err)
return nil, fmt.Errorf("failed to init ipset: %w", err)
} }
m.seedInitialEntries()
err = m.cleanChains()
if err != nil {
return nil, err
}
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil return m, nil
} }
func (m *aclManager) AddFiltering( func (m *aclManager) init(stateManager *statemanager.Manager) error {
m.stateManager = stateManager
m.seedInitialEntries()
m.seedInitialOptionalEntries()
if err := m.cleanChains(); err != nil {
return fmt.Errorf("clean chains: %w", err)
}
if err := m.createDefaultChains(); err != nil {
return fmt.Errorf("create default chains: %w", err)
}
m.updateState()
return nil
}
func (m *aclManager) AddPeerFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -127,7 +145,7 @@ func (m *aclManager) AddFiltering(
return nil, fmt.Errorf("rule already exists") return nil, fmt.Errorf("rule already exists")
} }
if err := m.iptablesClient.Insert("filter", chain, 1, specs...); err != nil { if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
return nil, err return nil, err
} }
@@ -139,28 +157,18 @@ func (m *aclManager) AddFiltering(
chain: chain, chain: chain,
} }
if !shouldAddToPrerouting(protocol, dPort, direction) { m.updateState()
return []firewall.Rule{rule}, nil
}
rulePrerouting, err := m.addPreroutingFilter(ipsetName, string(protocol), dPortVal, ip) return []firewall.Rule{rule}, nil
if err != nil {
return []firewall.Rule{rule}, err
}
return []firewall.Rule{rule, rulePrerouting}, nil
} }
// DeleteRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
func (m *aclManager) DeleteRule(rule firewall.Rule) error { func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
} }
if r.chain == "PREROUTING" {
goto DELETERULE
}
if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok { if ipsetList, ok := m.ipsetStore.ipset(r.ipsetName); ok {
// delete IP from ruleset IPs list and ipset // delete IP from ruleset IPs list and ipset
if _, ok := ipsetList.ips[r.ip]; ok { if _, ok := ipsetList.ips[r.ip]; ok {
@@ -185,60 +193,23 @@ func (m *aclManager) DeleteRule(rule firewall.Rule) error {
} }
} }
DELETERULE: if err := m.iptablesClient.Delete(tableName, r.chain, r.specs...); err != nil {
var table string return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
if r.chain == "PREROUTING" {
table = "mangle"
} else {
table = "filter"
} }
err := m.iptablesClient.Delete(table, r.chain, r.specs...)
if err != nil { m.updateState()
log.Debugf("failed to delete rule, %s, %v: %s", r.chain, r.specs, err)
} return nil
return err
} }
func (m *aclManager) Reset() error { func (m *aclManager) Reset() error {
return m.cleanChains() if err := m.cleanChains(); err != nil {
} return fmt.Errorf("clean chains: %w", err)
func (m *aclManager) addPreroutingFilter(ipsetName string, protocol string, port string, ip net.IP) (*Rule, error) {
var src []string
if ipsetName != "" {
src = []string{"-m", "set", "--set", ipsetName, "src"}
} else {
src = []string{"-s", ip.String()}
}
specs := []string{
"-d", m.wgIface.Address().IP.String(),
"-p", protocol,
"--dport", port,
"-j", "MARK", "--set-mark", postRoutingMark,
} }
specs = append(src, specs...) m.updateState()
ok, err := m.iptablesClient.Exists("mangle", "PREROUTING", specs...) return nil
if err != nil {
return nil, fmt.Errorf("failed to check rule: %w", err)
}
if ok {
return nil, fmt.Errorf("rule already exists")
}
if err := m.iptablesClient.Insert("mangle", "PREROUTING", 1, specs...); err != nil {
return nil, err
}
rule := &Rule{
ruleID: uuid.New().String(),
specs: specs,
ipsetName: ipsetName,
ip: ip.String(),
chain: "PREROUTING",
}
return rule, nil
} }
// todo write less destructive cleanup mechanism // todo write less destructive cleanup mechanism
@@ -293,8 +264,7 @@ func (m *aclManager) cleanChains() error {
ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING") ok, err = m.iptablesClient.ChainExists("mangle", "PREROUTING")
if err != nil { if err != nil {
log.Debugf("failed to list chains: %s", err) return fmt.Errorf("list chains: %w", err)
return err
} }
if ok { if ok {
for _, rule := range m.entries["PREROUTING"] { for _, rule := range m.entries["PREROUTING"] {
@@ -303,11 +273,6 @@ func (m *aclManager) cleanChains() error {
log.Errorf("failed to delete rule: %v, %s", rule, err) log.Errorf("failed to delete rule: %v, %s", rule, err)
} }
} }
err = m.iptablesClient.ClearChain("mangle", "PREROUTING")
if err != nil {
log.Debugf("failed to clear %s chain: %s", "PREROUTING", err)
return err
}
} }
for _, ipsetName := range m.ipsetStore.ipsetNames() { for _, ipsetName := range m.ipsetStore.ipsetNames() {
@@ -338,64 +303,98 @@ func (m *aclManager) createDefaultChains() error {
for chainName, rules := range m.entries { for chainName, rules := range m.entries {
for _, rule := range rules { for _, rule := range rules {
if chainName == "FORWARD" { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
// position 2 because we add it after router's, jump rule log.Debugf("failed to create input chain jump rule: %s", err)
if err := m.iptablesClient.InsertUnique(tableName, "FORWARD", 2, rule...); err != nil { return err
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} else {
if err := m.iptablesClient.AppendUnique(tableName, chainName, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
return err
}
} }
} }
} }
for chainName, entries := range m.optionalEntries {
for _, entry := range entries {
if err := m.iptablesClient.InsertUnique(tableName, chainName, entry.position, entry.spec...); err != nil {
log.Errorf("failed to insert optional entry %v: %v", entry.spec, err)
continue
}
m.entries[chainName] = append(m.entries[chainName], entry.spec)
}
}
clear(m.optionalEntries)
return nil return nil
} }
// seedInitialEntries adds default rules to the entries map, rules are inserted on pos 1, hence the order is reversed.
// We want to make sure our traffic is not dropped by existing rules.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
func (m *aclManager) seedInitialEntries() { func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT", established := getConntrackEstablished()
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("INPUT",
[]string{"-i", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("OUTPUT", m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
[]string{"-o", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT",
[]string{"-o", m.wgIface.Name(), "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().String(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
[]string{"-o", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"}) }
m.appendToEntries("FORWARD",
[]string{"-i", m.wgIface.Name(), "-m", "mark", "--mark", postRoutingMark, "-j", "ACCEPT"})
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routeingFwChainName})
m.appendToEntries("PREROUTING", func (m *aclManager) seedInitialOptionalEntries() {
[]string{"-t", "mangle", "-i", m.wgIface.Name(), "!", "-s", m.wgIface.Address().String(), "-d", m.wgIface.Address().IP.String(), "-m", "mark", "--mark", postRoutingMark}) m.optionalEntries["FORWARD"] = []entry{
{
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
position: 2,
},
}
m.optionalEntries["PREROUTING"] = []entry{
{
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
position: 1,
},
}
} }
func (m *aclManager) appendToEntries(chainName string, spec []string) { func (m *aclManager) appendToEntries(chainName string, spec []string) {
m.entries[chainName] = append(m.entries[chainName], spec) m.entries[chainName] = append(m.entries[chainName], spec)
} }
func (m *aclManager) updateState() {
if m.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := m.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.ACLEntries = m.entries
currentState.ACLIPsetStore = m.ipsetStore
if err := m.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
// filterRuleSpecs returns the specs of a filtering rule // filterRuleSpecs returns the specs of a filtering rule
func filterRuleSpecs( func filterRuleSpecs(
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string, ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
@@ -456,18 +455,3 @@ func transformIPsetName(ipsetName string, sPort, dPort string) string {
return ipsetName return ipsetName
} }
} }
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil {
return false
}
return true
}

View File

@@ -4,13 +4,17 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Manager of iptables firewall // Manager of iptables firewall
@@ -21,7 +25,7 @@ type Manager struct {
ipv4Client *iptables.IPTables ipv4Client *iptables.IPTables
aclMgr *aclManager aclMgr *aclManager
router *routerManager router *router
} }
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
@@ -32,10 +36,10 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("iptables is not installed in the system or not supported") return nil, fmt.Errorf("init iptables: %w", err)
} }
m := &Manager{ m := &Manager{
@@ -43,24 +47,55 @@ func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouterManager(context, iptablesClient) m.router, err = newRouter(iptablesClient, wgIface)
if err != nil { if err != nil {
log.Debugf("failed to initialize route related chains: %s", err) return nil, fmt.Errorf("create router: %w", err)
return nil, err
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, m.router.RouteingFwChainName())
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
if err != nil { if err != nil {
log.Debugf("failed to initialize ACL manager: %s", err) return nil, fmt.Errorf("create acl manager: %w", err)
return nil, err
} }
return m, nil return m, nil
} }
// AddFiltering rule to the firewall func (m *Manager) Init(stateManager *statemanager.Manager) error {
state := &ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}
stateManager.RegisterState(state)
if err := stateManager.UpdateState(state); err != nil {
log.Errorf("failed to update state: %v", err)
}
if err := m.router.init(stateManager); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclMgr.init(stateManager); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering adds a rule to the firewall
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddFiltering( func (m *Manager) AddPeerFiltering(
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -73,50 +108,86 @@ func (m *Manager) AddFiltering(
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName) return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
} }
// DeleteRule from the firewall by rule definition func (m *Manager) AddRouteFiltering(
func (m *Manager) DeleteRule(rule firewall.Rule) error { sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.DeleteRule(rule) if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclMgr.DeletePeerRule(rule)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.InsertRoutingRules(pair) return m.router.AddNatRule(pair)
} }
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveRoutingRules(pair) return m.router.RemoveNatRule(pair)
}
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
errAcl := m.aclMgr.Reset() var merr *multierror.Error
if errAcl != nil {
log.Errorf("failed to clean up ACL rules from firewall: %s", errAcl) if err := m.aclMgr.Reset(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err))
} }
errMgr := m.router.Reset() if err := m.router.Reset(); err != nil {
if errMgr != nil { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
log.Errorf("failed to clean up router rules from firewall: %s", errMgr)
return errMgr
} }
return errAcl
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete state: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -125,7 +196,7 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
_, err := m.AddFiltering( _, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -138,7 +209,7 @@ func (m *Manager) AllowNetbird() error {
if err != nil { if err != nil {
return fmt.Errorf("failed to allow netbird interface traffic: %w", err) return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
} }
_, err = m.AddFiltering( _, err = m.AddPeerFiltering(
net.ParseIP("0.0.0.0"), net.ParseIP("0.0.0.0"),
"all", "all",
nil, nil,
@@ -153,3 +224,7 @@ func (m *Manager) AllowNetbird() error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
}

View File

@@ -1,7 +1,6 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"testing" "testing"
@@ -11,9 +10,24 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -40,29 +54,15 @@ func TestIptablesManager(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -72,7 +72,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("add first rule", func(t *testing.T) { t.Run("add first rule", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule1 { for _, r := range rule1 {
@@ -87,7 +87,7 @@ func TestIptablesManager(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{8043: 8046}, Values: []int{8043: 8046},
} }
rule2, err = manager.AddFiltering( rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range") ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
@@ -99,7 +99,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeleteRule(r) err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...) checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
@@ -108,7 +108,7 @@ func TestIptablesManager(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeleteRule(r) err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -119,10 +119,10 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []int{5353}} port := &fw.Port{Values: []int{5353}}
_, err = manager.AddFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
@@ -154,13 +154,14 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -170,7 +171,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("add first rule with set", func(t *testing.T) { t.Run("add first rule with set", func(t *testing.T) {
ip := net.ParseIP("10.20.0.2") ip := net.ParseIP("10.20.0.2")
port := &fw.Port{Values: []int{8080}} port := &fw.Port{Values: []int{8080}}
rule1, err = manager.AddFiltering( rule1, err = manager.AddPeerFiltering(
ip, "tcp", nil, port, fw.RuleDirectionOUT, ip, "tcp", nil, port, fw.RuleDirectionOUT,
fw.ActionAccept, "default", "accept HTTP traffic", fw.ActionAccept, "default", "accept HTTP traffic",
) )
@@ -189,7 +190,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []int{443}, Values: []int{443},
} }
rule2, err = manager.AddFiltering( rule2, err = manager.AddPeerFiltering(
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
"default", "accept HTTPS traffic from ports range", "default", "accept HTTPS traffic from ports range",
) )
@@ -202,7 +203,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete first rule", func(t *testing.T) { t.Run("delete first rule", func(t *testing.T) {
for _, r := range rule1 { for _, r := range rule1 {
err := manager.DeleteRule(r) err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index") require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
@@ -211,7 +212,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
t.Run("delete second rule", func(t *testing.T) { t.Run("delete second rule", func(t *testing.T) {
for _, r := range rule2 { for _, r := range rule2 {
err := manager.DeleteRule(r) err := manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty") require.Empty(t, manager.aclMgr.ipsetStore.ipsets, "rulesets index after removed second rule must be empty")
@@ -219,7 +220,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -251,12 +252,13 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset() err := manager.Reset(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -269,9 +271,9 @@ func TestIptablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -3,370 +3,597 @@
package iptables package iptables
import ( import (
"context"
"fmt" "fmt"
"net/netip"
"strconv"
"strings" "strings"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/hashicorp/go-multierror"
"github.com/nadoo/ipset"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
) "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
const ( "github.com/netbirdio/netbird/client/internal/statemanager"
Ipv4Forwarding = "netbird-rt-forwarding" nbnet "github.com/netbirdio/netbird/util/net"
ipv4Nat = "netbird-rt-nat"
) )
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
chainFORWARD = "FORWARD" tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWD = "NETBIRD-RT-FWD"
chainRTPRE = "NETBIRD-RT-PRE"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre"
jumpNat = "jump-nat"
matchSet = "--match-set"
) )
type routerManager struct { type routeFilteringRuleParams struct {
ctx context.Context Sources []netip.Prefix
stop context.CancelFunc Destination netip.Prefix
iptablesClient *iptables.IPTables Proto firewall.Protocol
rules map[string][]string SPort *firewall.Port
DPort *firewall.Port
Direction firewall.RuleDirection
Action firewall.Action
SetName string
} }
func newRouterManager(parentCtx context.Context, iptablesClient *iptables.IPTables) (*routerManager, error) { type routeRules map[string][]string
ctx, cancel := context.WithCancel(parentCtx)
m := &routerManager{ type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
ctx: ctx,
stop: cancel, type router struct {
iptablesClient *iptables.IPTables
rules routeRules
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
stateManager *statemanager.Manager
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
r := &router{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface,
} }
err := m.cleanUpDefaultForwardRules() r.ipsetCounter = refcounter.New(
if err != nil { func(name string, sources []netip.Prefix) (struct{}, error) {
log.Errorf("failed to cleanup routing rules: %s", err) return struct{}{}, r.createIpSet(name, sources)
return nil, err },
func(name string, _ struct{}) error {
return r.deleteIpSet(name)
},
)
if err := ipset.Init(); err != nil {
return nil, fmt.Errorf("init ipset: %w", err)
} }
err = m.createContainers()
if err != nil { return r, nil
log.Errorf("failed to create containers for route: %s", err)
}
return m, err
} }
// InsertRoutingRules inserts an iptables rule pair to the forwarding chain and if enabled, to the nat chain func (r *router) init(stateManager *statemanager.Manager) error {
func (i *routerManager) InsertRoutingRules(pair firewall.RouterPair) error { r.stateManager = stateManager
err := i.insertRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, pair)
if err != nil { if err := r.cleanUpDefaultForwardRules(); err != nil {
return err log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
} }
err = i.insertRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, routingFinalForwardJump, firewall.GetInPair(pair)) if err := r.createContainers(); err != nil {
if err != nil { return fmt.Errorf("create containers: %w", err)
return err }
r.updateState()
return nil
}
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
var setName string
if len(sources) > 1 {
setName = firewall.GenerateSetName(sources)
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
return nil, fmt.Errorf("create or get ipset: %w", err)
}
}
params := routeFilteringRuleParams{
Sources: sources,
Destination: destination,
Proto: proto,
SPort: sPort,
DPort: dPort,
Action: action,
SetName: setName,
}
rule := genRouteFilteringRuleSpec(params)
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return nil, fmt.Errorf("add route rule: %v", err)
}
r.rules[string(ruleKey)] = rule
r.updateState()
return ruleKey, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID()
if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err)
}
delete(r.rules, ruleKey)
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("failed to remove ipset: %w", err)
}
}
} else {
log.Debugf("route rule %s not found", ruleKey)
}
r.updateState()
return nil
}
func (r *router) findSetNameInRule(rule []string) string {
for i, arg := range rule {
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
return rule[i+3]
}
}
return ""
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
if err := ipset.Create(setName, ipset.OptTimeout(0)); err != nil {
return fmt.Errorf("create set %s: %w", setName, err)
}
for _, prefix := range sources {
if err := ipset.AddPrefix(setName, prefix); err != nil {
return fmt.Errorf("add element to set %s: %w", setName, err)
}
}
return nil
}
func (r *router) deleteIpSet(setName string) error {
if err := ipset.Destroy(setName); err != nil {
return fmt.Errorf("destroy set %s: %w", setName, err)
}
return nil
}
// AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
} }
if !pair.Masquerade { if !pair.Masquerade {
return nil return nil
} }
err = i.addNATRule(firewall.NatFormat, tableNat, chainRTNAT, routingFinalNatJump, pair) if err := r.addNatRule(pair); err != nil {
if err != nil { return fmt.Errorf("add nat rule: %w", err)
return err
} }
err = i.addNATRule(firewall.InNatFormat, tableNat, chainRTNAT, routingFinalNatJump, firewall.GetInPair(pair)) if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
if err != nil { return fmt.Errorf("add inverse nat rule: %w", err)
return err
} }
r.updateState()
return nil return nil
} }
// insertRoutingRule inserts an iptables rule // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) insertRoutingRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
var err error if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err)
}
ruleKey := firewall.GenKey(keyFormat, pair.ID) if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
rule := genRuleSpec(jump, pair.Source, pair.Destination) return fmt.Errorf("remove inverse nat rule: %w", err)
existingRule, found := i.rules[ruleKey] }
if found {
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) if err := r.removeLegacyRouteRule(pair); err != nil {
if err != nil { return fmt.Errorf("remove legacy routing rule: %w", err)
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) }
r.updateState()
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if err := r.removeLegacyRouteRule(pair); err != nil {
return err
}
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(i.rules, ruleKey) delete(r.rules, ruleKey)
} } else {
log.Debugf("legacy forwarding rule %s not found", ruleKey)
err = i.iptablesClient.Insert(table, chain, 1, rule...)
if err != nil {
return fmt.Errorf("error while adding new %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// RemoveRoutingRules removes an iptables rule pair from forwarding and nat chains
func (i *routerManager) RemoveRoutingRules(pair firewall.RouterPair) error {
err := i.removeRoutingRule(firewall.ForwardingFormat, tableFilter, chainRTFWD, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InForwardingFormat, tableFilter, chainRTFWD, firewall.GetInPair(pair))
if err != nil {
return err
}
if !pair.Masquerade {
return nil
}
err = i.removeRoutingRule(firewall.NatFormat, tableNat, chainRTNAT, pair)
if err != nil {
return err
}
err = i.removeRoutingRule(firewall.InNatFormat, tableNat, chainRTNAT, firewall.GetInPair(pair))
if err != nil {
return err
} }
return nil return nil
} }
func (i *routerManager) removeRoutingRule(keyFormat, table, chain string, pair firewall.RouterPair) error { // GetLegacyManagement returns the current legacy management mode
var err error func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
ruleKey := firewall.GenKey(keyFormat, pair.ID) // SetLegacyManagement sets the route manager to use legacy management mode
existingRule, found := i.rules[ruleKey] func (r *router) SetLegacyManagement(isLegacy bool) {
if found { r.legacyManagement = isLegacy
err = i.iptablesClient.DeleteIfExists(table, chain, existingRule...) }
if err != nil {
return fmt.Errorf("error while removing existing %s rule for %s: %v", getIptablesRuleType(table), pair.Destination, err) // RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
} }
} }
delete(i.rules, ruleKey)
return nil r.updateState()
return nberrors.FormatErrorOrNil(merr)
} }
func (i *routerManager) RouteingFwChainName() string { func (r *router) Reset() error {
return chainRTFWD var merr *multierror.Error
} if err := r.cleanUpDefaultForwardRules(); err != nil {
merr = multierror.Append(merr, err)
func (i *routerManager) Reset() error {
err := i.cleanUpDefaultForwardRules()
if err != nil {
return err
} }
i.rules = make(map[string][]string) r.rules = make(map[string][]string)
return nil
if err := r.ipsetCounter.Flush(); err != nil {
merr = multierror.Append(merr, err)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
} }
func (i *routerManager) cleanUpDefaultForwardRules() error { func (r *router) cleanUpDefaultForwardRules() error {
err := i.cleanJumpRules() if err := r.cleanJumpRules(); err != nil {
if err != nil { return fmt.Errorf("clean jump rules: %w", err)
return err
} }
log.Debug("flushing routing related tables") log.Debug("flushing routing related tables")
ok, err := i.iptablesClient.ChainExists(tableFilter, chainRTFWD) for _, chainInfo := range []struct {
if err != nil { chain string
log.Errorf("failed check chain %s,error: %v", chainRTFWD, err) table string
return err }{
} else if ok { {chainRTFWD, tableFilter},
err = i.iptablesClient.ClearAndDeleteChain(tableFilter, chainRTFWD) {chainRTNAT, tableNat},
{chainRTPRE, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
log.Errorf("failed cleaning chain %s,error: %v", chainRTFWD, err) return fmt.Errorf("check chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
return err } else if ok {
if err = r.iptablesClient.ClearAndDeleteChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("clear and delete chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
} }
} }
ok, err = i.iptablesClient.ChainExists(tableNat, chainRTNAT) return nil
}
func (r *router) createContainers() error {
for _, chainInfo := range []struct {
chain string
table string
}{
{chainRTFWD, tableFilter},
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
} {
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
}
}
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add static nat rules: %w", err)
}
if err := r.addJumpRules(); err != nil {
return fmt.Errorf("add jump rules: %w", err)
}
return nil
}
func (r *router) addPostroutingRules() error {
// First rule for outbound masquerade
rule1 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
"!", "-o", "lo",
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule1...); err != nil {
return fmt.Errorf("add outbound masquerade rule: %v", err)
}
r.rules["static-nat-outbound"] = rule1
// Second rule for return traffic masquerade
rule2 := []string{
"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
"-o", r.wgIface.Name(),
"-j", routingFinalNatJump,
}
if err := r.iptablesClient.Append(tableNat, chainRTNAT, rule2...); err != nil {
return fmt.Errorf("add return masquerade rule: %v", err)
}
r.rules["static-nat-return"] = rule2
return nil
}
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
err := r.iptablesClient.Insert(tableFilter, chain, 1, establishedRule...)
if err != nil { if err != nil {
log.Errorf("failed check chain %s,error: %v", chainRTNAT, err) return fmt.Errorf("failed to insert established rule: %v", err)
return err }
} else if ok {
err = i.iptablesClient.ClearAndDeleteChain(tableNat, chainRTNAT) ruleKey := "established-" + chain
if err != nil { r.rules[ruleKey] = establishedRule
log.Errorf("failed cleaning chain %s,error: %v", chainRTNAT, err)
return err return nil
}
func (r *router) addJumpRules() error {
// Jump to NAT chain
natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err)
}
r.rules[jumpNat] = natRule
// Jump to prerouting chain
preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err)
}
r.rules[jumpPre] = preRule
return nil
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} {
if rule, exists := r.rules[ruleKey]; exists {
table := tableNat
chain := chainPOSTROUTING
if ruleKey == jumpPre {
table = tableMangle
chain = chainPREROUTING
}
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
return fmt.Errorf("delete rule from chain %s in table %s, err: %v", chain, table, err)
}
delete(r.rules, ruleKey)
} }
} }
return nil return nil
} }
func (i *routerManager) createContainers() error { func (r *router) addNatRule(pair firewall.RouterPair) error {
if i.rules[Ipv4Forwarding] != nil { ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing existing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
}
markValue := nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
rule := []string{"-i", r.wgIface.Name()}
if pair.Inverse {
rule = []string{"!", "-i", r.wgIface.Name()}
}
rule = append(rule,
"-m", "conntrack",
"--ctstate", "NEW",
"-s", pair.Source.String(),
"-d", pair.Destination.String(),
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
)
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
}
r.rules[ruleKey] = rule
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.NatFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainRTPRE, rule...); err != nil {
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
}
delete(r.rules, ruleKey)
} else {
log.Debugf("marking rule %s not found", ruleKey)
}
return nil
}
func (r *router) updateState() {
if r.stateManager == nil {
return
}
var currentState *ShutdownState
if existing := r.stateManager.GetState(currentState); existing != nil {
if existingState, ok := existing.(*ShutdownState); ok {
currentState = existingState
}
}
if currentState == nil {
currentState = &ShutdownState{}
}
currentState.Lock()
defer currentState.Unlock()
currentState.RouteRules = r.rules
currentState.RouteIPsetCounter = r.ipsetCounter
if err := r.stateManager.UpdateState(currentState); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string
if params.SetName != "" {
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
} else if len(params.Sources) > 0 {
source := params.Sources[0]
rule = append(rule, "-s", source.String())
}
rule = append(rule, "-d", params.Destination.String())
if params.Proto != firewall.ProtocolALL {
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
rule = append(rule, applyPort("--sport", params.SPort)...)
rule = append(rule, applyPort("--dport", params.DPort)...)
}
rule = append(rule, "-j", actionToStr(params.Action))
return rule
}
func applyPort(flag string, port *firewall.Port) []string {
if port == nil {
return nil return nil
} }
errMSGFormat := "failed creating chain %s,error: %v" if port.IsRange && len(port.Values) == 2 {
err := i.createChain(tableFilter, chainRTFWD) return []string{flag, fmt.Sprintf("%d:%d", port.Values[0], port.Values[1])}
if err != nil {
return fmt.Errorf(errMSGFormat, chainRTFWD, err)
} }
err = i.createChain(tableNat, chainRTNAT) if len(port.Values) > 1 {
if err != nil { portList := make([]string, len(port.Values))
return fmt.Errorf(errMSGFormat, chainRTNAT, err) for i, p := range port.Values {
portList[i] = strconv.Itoa(p)
}
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
} }
err = i.addJumpRules() return []string{flag, strconv.Itoa(port.Values[0])}
if err != nil {
return fmt.Errorf("error while creating jump rules: %v", err)
}
return nil
}
// addJumpRules create jump rules to send packets to NetBird chains
func (i *routerManager) addJumpRules() error {
rule := []string{"-j", chainRTFWD}
err := i.iptablesClient.Insert(tableFilter, chainFORWARD, 1, rule...)
if err != nil {
return err
}
i.rules[Ipv4Forwarding] = rule
rule = []string{"-j", chainRTNAT}
err = i.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, rule...)
if err != nil {
return err
}
i.rules[ipv4Nat] = rule
return nil
}
// cleanJumpRules cleans jump rules that was sending packets to NetBird chains
func (i *routerManager) cleanJumpRules() error {
var err error
errMSGFormat := "failed cleaning rule from chain %s,err: %v"
rule, found := i.rules[Ipv4Forwarding]
if found {
err = i.iptablesClient.DeleteIfExists(tableFilter, chainFORWARD, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainFORWARD, err)
}
}
rule, found = i.rules[ipv4Nat]
if found {
err = i.iptablesClient.DeleteIfExists(tableNat, chainPOSTROUTING, rule...)
if err != nil {
return fmt.Errorf(errMSGFormat, chainPOSTROUTING, err)
}
}
rules, err := i.iptablesClient.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("failed to list rules: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists("nat", "POSTROUTING", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete postrouting jump rule: %s", err)
}
}
rules, err = i.iptablesClient.List(tableFilter, "FORWARD")
if err != nil {
return fmt.Errorf("failed to list rules in FORWARD chain: %s", err)
}
for _, ruleString := range rules {
if !strings.Contains(ruleString, "NETBIRD") {
continue
}
rule := strings.Fields(ruleString)
err := i.iptablesClient.DeleteIfExists(tableFilter, "FORWARD", rule[2:]...)
if err != nil {
return fmt.Errorf("failed to delete FORWARD jump rule: %s", err)
}
}
return nil
}
func (i *routerManager) createChain(table, newChain string) error {
chains, err := i.iptablesClient.ListChains(table)
if err != nil {
return fmt.Errorf("couldn't get %s table chains, error: %v", table, err)
}
shouldCreateChain := true
for _, chain := range chains {
if chain == newChain {
shouldCreateChain = false
}
}
if shouldCreateChain {
err = i.iptablesClient.NewChain(table, newChain)
if err != nil {
return fmt.Errorf("couldn't create chain %s in %s table, error: %v", newChain, table, err)
}
// Add the loopback return rule to the NAT chain
loopbackRule := []string{"-o", "lo", "-j", "RETURN"}
err = i.iptablesClient.Insert(table, newChain, 1, loopbackRule...)
if err != nil {
return fmt.Errorf("failed to add loopback return rule to %s: %v", chainRTNAT, err)
}
err = i.iptablesClient.Append(table, newChain, "-j", "RETURN")
if err != nil {
return fmt.Errorf("couldn't create chain %s default rule, error: %v", newChain, err)
}
}
return nil
}
// addNATRule appends an iptables rule pair to the nat chain
func (i *routerManager) addNATRule(keyFormat, table, chain, jump string, pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(keyFormat, pair.ID)
rule := genRuleSpec(jump, pair.Source, pair.Destination)
existingRule, found := i.rules[ruleKey]
if found {
err := i.iptablesClient.DeleteIfExists(table, chain, existingRule...)
if err != nil {
return fmt.Errorf("error while removing existing NAT rule for %s: %v", pair.Destination, err)
}
delete(i.rules, ruleKey)
}
// inserting after loopback ignore rule
err := i.iptablesClient.Insert(table, chain, 2, rule...)
if err != nil {
return fmt.Errorf("error while appending new NAT rule for %s: %v", pair.Destination, err)
}
i.rules[ruleKey] = rule
return nil
}
// genRuleSpec generates rule specification
func genRuleSpec(jump, source, destination string) []string {
return []string{"-s", source, "-d", destination, "-j", jump}
}
func getIptablesRuleType(table string) string {
ruleType := "forwarding"
if table == tableNat {
ruleType = "nat"
}
return ruleType
} }

View File

@@ -3,16 +3,18 @@
package iptables package iptables
import ( import (
"context" "fmt"
"net/netip"
"os/exec" "os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
nbnet "github.com/netbirdio/netbird/util/net"
) )
func isIptablesSupported() bool { func isIptablesSupported() bool {
@@ -28,45 +30,45 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
require.Len(t, manager.rules, 2, "should have created rules map") // Now 5 rules:
// 1. established rule in forward chain
// 2. jump rule to NAT chain
// 3. jump rule to PRE chain
// 4. static outbound masquerade rule
// 5. static return masquerade rule
require.Len(t, manager.rules, 5, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableFilter, chainFORWARD, manager.rules[Ipv4Forwarding]...) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainFORWARD)
require.True(t, exists, "forwarding rule should exist")
exists, err = manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, manager.rules[ipv4Nat]...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
require.True(t, exists, "postrouting rule should exist") require.True(t, exists, "postrouting jump rule should exist")
exists, err = manager.iptablesClient.Exists(tableMangle, chainPREROUTING, "-j", chainRTPRE)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainPREROUTING)
require.True(t, exists, "prerouting jump rule should exist")
pair := firewall.RouterPair{ pair := firewall.RouterPair{
ID: "abc", ID: "abc",
Source: "100.100.100.1/32", Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: "100.100.100.0/24", Destination: netip.MustParsePrefix("100.100.100.0/24"),
Masquerade: true, Masquerade: true,
} }
forward4Rule := genRuleSpec(routingFinalForwardJump, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableFilter, chainRTFWD, 1, forward4Rule...) err = manager.AddNatRule(pair)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "adding NAT rule should not return error")
nat4Rule := genRuleSpec(routingFinalNatJump, pair.Source, pair.Destination)
err = manager.iptablesClient.Insert(tableNat, chainRTNAT, 1, nat4Rule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset() err = manager.Reset()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
} }
func TestIptablesManager_InsertRoutingRules(t *testing.T) { func TestIptablesManager_AddNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -76,78 +78,71 @@ func TestIptablesManager_InsertRoutingRules(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouterManager(context.TODO(), iptablesClient) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
err := manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
if err != nil {
log.Errorf("failed to reset iptables manager: %s", err)
}
}() }()
err = manager.InsertRoutingRules(testCase.InputPair) err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "forwarding pair should be inserted") require.NoError(t, err, "marking rule should be inserted")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) markingRule := []string{
"-i", ifaceMock.Name(),
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...) "-m", "conntrack",
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD) "--ctstate", "NEW",
require.True(t, exists, "forwarding rule should exist") "-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
foundRule, found := manager.rules[forwardRuleKey] "-j", "MARK", "--set-mark",
require.True(t, found, "forwarding rule should exist in the manager map") fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
require.Equal(t, forwardRule[:4], foundRule[:4], "stored forwarding rule should match")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.True(t, exists, "income forwarding rule should exist")
foundRule, found = manager.rules[inForwardRuleKey]
require.True(t, found, "income forwarding rule should exist in the manager map")
require.Equal(t, inForwardRule[:4], foundRule[:4], "stored income forwarding rule should match")
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade {
require.True(t, exists, "nat rule should be created")
foundNatRule, foundNat := manager.rules[natRuleKey]
require.True(t, foundNat, "nat rule should exist in the map")
require.Equal(t, natRule[:4], foundNatRule[:4], "stored nat rule should match")
} else {
require.False(t, exists, "nat rule should not be created")
_, foundNat := manager.rules[natRuleKey]
require.False(t, foundNat, "nat rule should not exist in the map")
} }
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
require.True(t, exists, "income nat rule should be created") require.True(t, exists, "marking rule should be created")
foundNatRule, foundNat := manager.rules[inNatRuleKey] foundRule, found := manager.rules[natRuleKey]
require.True(t, foundNat, "income nat rule should exist in the map") require.True(t, found, "marking rule should exist in the map")
require.Equal(t, inNatRule[:4], foundNatRule[:4], "stored income nat rule should match") require.Equal(t, markingRule, foundRule, "stored marking rule should match")
} else { } else {
require.False(t, exists, "nat rule should not be created") require.False(t, exists, "marking rule should not be created")
_, foundNat := manager.rules[inNatRuleKey] _, found := manager.rules[natRuleKey]
require.False(t, foundNat, "income nat rule should not exist in the map") require.False(t, found, "marking rule should not exist in the map")
}
// Check inverse rule
inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
if testCase.InputPair.Masquerade {
require.True(t, exists, "inverse marking rule should be created")
foundRule, found := manager.rules[inverseRuleKey]
require.True(t, found, "inverse marking rule should exist in the map")
require.Equal(t, inverseMarkingRule, foundRule, "stored inverse marking rule should match")
} else {
require.False(t, exists, "inverse marking rule should not be created")
_, found := manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
} }
}) })
} }
} }
func TestIptablesManager_RemoveRoutingRules(t *testing.T) { func TestIptablesManager_RemoveNatRule(t *testing.T) {
if !isIptablesSupported() { if !isIptablesSupported() {
t.SkipNow() t.SkipNow()
} }
@@ -156,72 +151,226 @@ func TestIptablesManager_RemoveRoutingRules(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouterManager(context.TODO(), iptablesClient) manager, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() { defer func() {
_ = manager.Reset() assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
err = manager.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule without error")
err = manager.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair)
forwardRule := genRuleSpec(routingFinalForwardJump, testCase.InputPair.Source, testCase.InputPair.Destination) markingRule := []string{
"-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", testCase.InputPair.Source.String(),
"-d", testCase.InputPair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasquerade),
}
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, forwardRule...) exists, err := iptablesClient.Exists(tableMangle, chainRTPRE, markingRule...)
require.NoError(t, err, "inserting rule should not return error") require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "marking rule should not exist")
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) _, found := manager.rules[natRuleKey]
inForwardRule := genRuleSpec(routingFinalForwardJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination) require.False(t, found, "marking rule should not exist in the manager map")
err = iptablesClient.Insert(tableFilter, chainRTFWD, 1, inForwardRule...) // Check inverse rule removal
require.NoError(t, err, "inserting rule should not return error") inversePair := firewall.GetInversePair(testCase.InputPair)
inverseRuleKey := firewall.GenKey(firewall.NatFormat, inversePair)
inverseMarkingRule := []string{
"!", "-i", ifaceMock.Name(),
"-m", "conntrack",
"--ctstate", "NEW",
"-s", inversePair.Source.String(),
"-d", inversePair.Destination.String(),
"-j", "MARK", "--set-mark",
fmt.Sprintf("%#x", nbnet.PreroutingFwmarkMasqueradeReturn),
}
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) exists, err = iptablesClient.Exists(tableMangle, chainRTPRE, inverseMarkingRule...)
natRule := genRuleSpec(routingFinalNatJump, testCase.InputPair.Source, testCase.InputPair.Destination) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableMangle, chainRTPRE)
require.False(t, exists, "inverse marking rule should not exist")
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, natRule...)
require.NoError(t, err, "inserting rule should not return error")
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
inNatRule := genRuleSpec(routingFinalNatJump, firewall.GetInPair(testCase.InputPair).Source, firewall.GetInPair(testCase.InputPair).Destination)
err = iptablesClient.Insert(tableNat, chainRTNAT, 1, inNatRule...)
require.NoError(t, err, "inserting rule should not return error")
err = manager.Reset()
require.NoError(t, err, "shouldn't return error")
err = manager.RemoveRoutingRules(testCase.InputPair)
require.NoError(t, err, "shouldn't return error")
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, forwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "forwarding rule should not exist")
_, found := manager.rules[forwardRuleKey]
require.False(t, found, "forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableFilter, chainRTFWD, inForwardRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableFilter, chainRTFWD)
require.False(t, exists, "income forwarding rule should not exist")
_, found = manager.rules[inForwardRuleKey]
require.False(t, found, "income forwarding rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, natRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "nat rule should not exist")
_, found = manager.rules[natRuleKey]
require.False(t, found, "nat rule should exist in the manager map")
exists, err = iptablesClient.Exists(tableNat, chainRTNAT, inNatRule...)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainRTNAT)
require.False(t, exists, "income nat rule should not exist")
_, found = manager.rules[inNatRuleKey]
require.False(t, found, "income nat rule should exist in the manager map")
_, found = manager.rules[inverseRuleKey]
require.False(t, found, "inverse marking rule should not exist in the map")
})
}
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if !isIptablesSupported() {
t.Skip("iptables not supported on this system")
}
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))
defer func() {
err := r.Reset()
require.NoError(t, err, "Failed to reset router")
}()
tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule
t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables")
// Verify rule content
params := routeFilteringRuleParams{
Sources: tt.sources,
Destination: tt.destination,
Proto: tt.proto,
SPort: tt.sPort,
DPort: tt.dPort,
Action: tt.action,
SetName: "",
}
expectedRule := genRouteFilteringRuleSpec(params)
if tt.expectSet {
setName := firewall.GenerateSetName(tt.sources)
params.SetName = setName
expectedRule = genRouteFilteringRuleSpec(params)
// Check if the set was created
_, exists := r.ipsetCounter.Get(setName)
assert.True(t, exists, "IPSet not created")
}
assert.Equal(t, expectedRule, rule, "Rule content mismatch")
// Clean up
err = r.DeleteRouteRule(ruleKey)
require.NoError(t, err, "Failed to delete rule")
}) })
} }
} }

View File

@@ -1,14 +1,16 @@
package iptables package iptables
import "encoding/json"
type ipList struct { type ipList struct {
ips map[string]struct{} ips map[string]struct{}
} }
func newIpList(ip string) ipList { func newIpList(ip string) *ipList {
ips := make(map[string]struct{}) ips := make(map[string]struct{})
ips[ip] = struct{}{} ips[ip] = struct{}{}
return ipList{ return &ipList{
ips: ips, ips: ips,
} }
} }
@@ -17,27 +19,47 @@ func (s *ipList) addIP(ip string) {
s.ips[ip] = struct{}{} s.ips[ip] = struct{}{}
} }
// MarshalJSON implements json.Marshaler
func (s *ipList) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPs map[string]struct{} `json:"ips"`
}{
IPs: s.ips,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipList) UnmarshalJSON(data []byte) error {
temp := struct {
IPs map[string]struct{} `json:"ips"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ips = temp.IPs
return nil
}
type ipsetStore struct { type ipsetStore struct {
ipsets map[string]ipList // ipsetName -> ruleset ipsets map[string]*ipList
} }
func newIpsetStore() *ipsetStore { func newIpsetStore() *ipsetStore {
return &ipsetStore{ return &ipsetStore{
ipsets: make(map[string]ipList), ipsets: make(map[string]*ipList),
} }
} }
func (s *ipsetStore) ipset(ipsetName string) (ipList, bool) { func (s *ipsetStore) ipset(ipsetName string) (*ipList, bool) {
r, ok := s.ipsets[ipsetName] r, ok := s.ipsets[ipsetName]
return r, ok return r, ok
} }
func (s *ipsetStore) addIpList(ipsetName string, list ipList) { func (s *ipsetStore) addIpList(ipsetName string, list *ipList) {
s.ipsets[ipsetName] = list s.ipsets[ipsetName] = list
} }
func (s *ipsetStore) deleteIpset(ipsetName string) { func (s *ipsetStore) deleteIpset(ipsetName string) {
s.ipsets[ipsetName] = ipList{}
delete(s.ipsets, ipsetName) delete(s.ipsets, ipsetName)
} }
@@ -48,3 +70,24 @@ func (s *ipsetStore) ipsetNames() []string {
} }
return names return names
} }
// MarshalJSON implements json.Marshaler
func (s *ipsetStore) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
IPSets map[string]*ipList `json:"ipsets"`
}{
IPSets: s.ipsets,
})
}
// UnmarshalJSON implements json.Unmarshaler
func (s *ipsetStore) UnmarshalJSON(data []byte) error {
temp := struct {
IPSets map[string]*ipList `json:"ipsets"`
}{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}
s.ipsets = temp.IPSets
return nil
}

View File

@@ -0,0 +1,70 @@
package iptables
import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
sync.Mutex
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
RouteRules routeRules `json:"route_rules,omitempty"`
RouteIPsetCounter *ipsetCounter `json:"route_ipset_counter,omitempty"`
ACLEntries aclEntries `json:"acl_entries,omitempty"`
ACLIPsetStore *ipsetStore `json:"acl_ipset_store,omitempty"`
}
func (s *ShutdownState) Name() string {
return "iptables_state"
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}
if s.RouteRules != nil {
ipt.router.rules = s.RouteRules
}
if s.RouteIPsetCounter != nil {
ipt.router.ipsetCounter.LoadData(s.RouteIPsetCounter)
}
if s.ACLEntries != nil {
ipt.aclMgr.entries = s.ACLEntries
}
if s.ACLIPsetStore != nil {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
}
if err := ipt.Reset(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err)
}
return nil
}

View File

@@ -1,15 +1,24 @@
package manager package manager
import ( import (
"crypto/sha256"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip"
"sort"
"strings"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
NatFormat = "netbird-nat-%s" ForwardingFormatPrefix = "netbird-fwd-"
ForwardingFormat = "netbird-fwd-%s" ForwardingFormat = "netbird-fwd-%s-%t"
InNatFormat = "netbird-nat-in-%s" PreroutingFormat = "netbird-prerouting-%s-%t"
InForwardingFormat = "netbird-fwd-in-%s" NatFormat = "netbird-nat-%s-%t"
) )
// Rule abstraction should be implemented by each firewall manager // Rule abstraction should be implemented by each firewall manager
@@ -46,14 +55,16 @@ const (
// It declares methods which handle actions required by the // It declares methods which handle actions required by the
// Netbird client for ACL and routing functionality // Netbird client for ACL and routing functionality
type Manager interface { type Manager interface {
Init(stateManager *statemanager.Manager) error
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
AllowNetbird() error AllowNetbird() error
// AddFiltering rule to the firewall // AddPeerFiltering adds a rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddFiltering( AddPeerFiltering(
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
@@ -64,25 +75,116 @@ type Manager interface {
comment string, comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeleteRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
DeleteRule(rule Rule) error DeletePeerRule(rule Rule) error
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
// InsertRoutingRules inserts a routing firewall rule AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
InsertRoutingRules(pair RouterPair) error
// RemoveRoutingRules removes a routing firewall rule // DeleteRouteRule deletes a routing rule
RemoveRoutingRules(pair RouterPair) error DeleteRouteRule(rule Rule) error
// AddNatRule inserts a routing NAT rule
AddNatRule(pair RouterPair) error
// RemoveNatRule removes a routing NAT rule
RemoveNatRule(pair RouterPair) error
// SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Reset firewall to the default state
Reset() error Reset(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error
} }
func GenKey(format string, input string) string { func GenKey(format string, pair RouterPair) string {
return fmt.Sprintf(format, input) return fmt.Sprintf(format, pair.ID, pair.Inverse)
}
// LegacyManager defines the interface for legacy management operations
type LegacyManager interface {
RemoveAllLegacyRouteRules() error
GetLegacyManagement() bool
SetLegacyManagement(bool)
}
// SetLegacyManagement sets the route manager to use legacy management
func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
oldLegacy := router.GetLegacyManagement()
if oldLegacy != isLegacy {
router.SetLegacyManagement(isLegacy)
log.Debugf("Set legacy management to %v", isLegacy)
}
// client reconnected to a newer mgmt, we need to clean up the legacy rules
if !isLegacy && oldLegacy {
if err := router.RemoveAllLegacyRouteRules(); err != nil {
return fmt.Errorf("remove legacy routing rules: %v", err)
}
log.Debugf("Legacy routing rules removed")
}
return nil
}
// GenerateSetName generates a unique name for an ipset based on the given sources.
func GenerateSetName(sources []netip.Prefix) string {
// sort for consistent naming
SortPrefixes(sources)
var sourcesStr strings.Builder
for _, src := range sources {
sourcesStr.WriteString(src.String())
}
hash := sha256.Sum256([]byte(sourcesStr.String()))
shortHash := hex.EncodeToString(hash[:])[:8]
return fmt.Sprintf("nb-%s", shortHash)
}
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
if len(prefixes) == 0 {
return prefixes
}
merged := []netip.Prefix{prefixes[0]}
for _, prefix := range prefixes[1:] {
last := merged[len(merged)-1]
if last.Contains(prefix.Addr()) {
// If the current prefix is contained within the last merged prefix, skip it
continue
}
if prefix.Contains(last.Addr()) {
// If the current prefix contains the last merged prefix, replace it
merged[len(merged)-1] = prefix
} else {
// Otherwise, add the current prefix to the merged list
merged = append(merged, prefix)
}
}
return merged
}
// SortPrefixes sorts the given slice of netip.Prefix in place.
// It sorts first by IP address, then by prefix length (most specific to least specific).
func SortPrefixes(prefixes []netip.Prefix) {
sort.Slice(prefixes, func(i, j int) bool {
addrCmp := prefixes[i].Addr().Compare(prefixes[j].Addr())
if addrCmp != 0 {
return addrCmp < 0
}
// If IP addresses are the same, compare prefix lengths (longer prefixes first)
return prefixes[i].Bits() > prefixes[j].Bits()
})
} }

View File

@@ -0,0 +1,192 @@
package manager_test
import (
"net/netip"
"reflect"
"regexp"
"testing"
"github.com/netbirdio/netbird/client/firewall/manager"
)
func TestGenerateSetName(t *testing.T) {
t.Run("Different orders result in same hash", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
}
})
t.Run("Result format is correct", func(t *testing.T) {
prefixes := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
}
result := manager.GenerateSetName(prefixes)
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
if err != nil {
t.Fatalf("Error matching regex: %v", err)
}
if !matched {
t.Errorf("Result format is incorrect: %s", result)
}
})
t.Run("Empty input produces consistent result", func(t *testing.T) {
result1 := manager.GenerateSetName([]netip.Prefix{})
result2 := manager.GenerateSetName([]netip.Prefix{})
if result1 != result2 {
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
}
})
t.Run("IPv4 and IPv6 mixing", func(t *testing.T) {
prefixes1 := []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("2001:db8::/32"),
}
prefixes2 := []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("192.168.1.0/24"),
}
result1 := manager.GenerateSetName(prefixes1)
result2 := manager.GenerateSetName(prefixes2)
if result1 != result2 {
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
}
})
}
func TestMergeIPRanges(t *testing.T) {
tests := []struct {
name string
input []netip.Prefix
expected []netip.Prefix
}{
{
name: "Empty input",
input: []netip.Prefix{},
expected: []netip.Prefix{},
},
{
name: "Single range",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Two non-overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("10.0.0.0/8"),
},
},
{
name: "One range containing another",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "One range containing another (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Overlapping ranges (different order)",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.1.128/25"),
netip.MustParsePrefix("192.168.1.0/24"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"),
},
},
{
name: "Multiple overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/24"),
netip.MustParsePrefix("192.168.1.128/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"),
},
},
{
name: "Partially overlapping ranges",
input: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.2.0/25"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/23"),
netip.MustParsePrefix("192.168.2.0/25"),
},
},
{
name: "IPv6 ranges",
input: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
netip.MustParsePrefix("2001:db8:1::/48"),
netip.MustParsePrefix("2001:db8:2::/48"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("2001:db8::/32"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := manager.MergeIPRanges(tt.input)
if !reflect.DeepEqual(result, tt.expected) {
t.Errorf("MergeIPRanges() = %v, want %v", result, tt.expected)
}
})
}
}

View File

@@ -1,18 +1,26 @@
package manager package manager
import (
"net/netip"
"github.com/netbirdio/netbird/route"
)
type RouterPair struct { type RouterPair struct {
ID string ID route.ID
Source string Source netip.Prefix
Destination string Destination netip.Prefix
Masquerade bool Masquerade bool
Inverse bool
} }
func GetInPair(pair RouterPair) RouterPair { func GetInversePair(pair RouterPair) RouterPair {
return RouterPair{ return RouterPair{
ID: pair.ID, ID: pair.ID,
// invert Source/Destination // invert Source/Destination
Source: pair.Destination, Source: pair.Destination,
Destination: pair.Source, Destination: pair.Source,
Masquerade: pair.Masquerade, Masquerade: pair.Masquerade,
Inverse: true,
} }
} }

View File

@@ -11,12 +11,13 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
@@ -29,72 +30,63 @@ const (
chainNameInputFilter = "netbird-acl-input-filter" chainNameInputFilter = "netbird-acl-input-filter"
chainNameOutputFilter = "netbird-acl-output-filter" chainNameOutputFilter = "netbird-acl-output-filter"
chainNameForwardFilter = "netbird-acl-forward-filter" chainNameForwardFilter = "netbird-acl-forward-filter"
chainNamePrerouting = "netbird-rt-prerouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic" allowNetbirdInputRuleID = "allow Netbird incoming traffic"
) )
const flushError = "flush: %w"
var ( var (
anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} anyIP = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
postroutingMark = []byte{0xe4, 0x7, 0x0, 0x00}
) )
type AclManager struct { type AclManager struct {
rConn *nftables.Conn rConn *nftables.Conn
sConn *nftables.Conn sConn *nftables.Conn
wgIface iFaceMapper wgIface iFaceMapper
routeingFwChainName string routingFwChainName string
workTable *nftables.Table workTable *nftables.Table
chainInputRules *nftables.Chain chainInputRules *nftables.Chain
chainOutputRules *nftables.Chain chainOutputRules *nftables.Chain
chainFwFilter *nftables.Chain
chainPrerouting *nftables.Chain
ipsetStore *ipsetStore ipsetStore *ipsetStore
rules map[string]*Rule rules map[string]*Rule
} }
// iFaceMapper defines subset methods of interface required for manager func newAclManager(table *nftables.Table, wgIface iFaceMapper, routingFwChainName string) (*AclManager, error) {
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
func newAclManager(table *nftables.Table, wgIface iFaceMapper, routeingFwChainName string) (*AclManager, error) {
// sConn is used for creating sets and adding/removing elements from them // sConn is used for creating sets and adding/removing elements from them
// it's differ then rConn (which does create new conn for each flush operation) // it's differ then rConn (which does create new conn for each flush operation)
// and is permanent. Using same connection for booth type of operations // and is permanent. Using same connection for both type of operations
// overloads netlink with high amount of rules ( > 10000) // overloads netlink with high amount of rules ( > 10000)
sConn, err := nftables.New(nftables.AsLasting()) sConn, err := nftables.New(nftables.AsLasting())
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create nf conn: %w", err)
} }
m := &AclManager{ return &AclManager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
sConn: sConn, sConn: sConn,
wgIface: wgIface, wgIface: wgIface,
workTable: table, workTable: table,
routeingFwChainName: routeingFwChainName, routingFwChainName: routingFwChainName,
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
rules: make(map[string]*Rule), rules: make(map[string]*Rule),
} }, nil
err = m.createDefaultChains()
if err != nil {
return nil, err
}
return m, nil
} }
// AddFiltering rule to the firewall func (m *AclManager) init(workTable *nftables.Table) error {
m.workTable = workTable
return m.createDefaultChains()
}
// AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddFiltering( func (m *AclManager) AddPeerFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -120,20 +112,11 @@ func (m *AclManager) AddFiltering(
} }
newRules = append(newRules, ioRule) newRules = append(newRules, ioRule)
if !shouldAddToPrerouting(proto, dPort, direction) {
return newRules, nil
}
preroutingRule, err := m.addPreroutingFiltering(ipset, proto, dPort, ip)
if err != nil {
return newRules, err
}
newRules = append(newRules, preroutingRule)
return newRules, nil return newRules, nil
} }
// DeleteRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
func (m *AclManager) DeleteRule(rule firewall.Rule) error { func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
r, ok := rule.(*Rule) r, ok := rule.(*Rule)
if !ok { if !ok {
return fmt.Errorf("invalid rule type") return fmt.Errorf("invalid rule type")
@@ -199,8 +182,7 @@ func (m *AclManager) DeleteRule(rule firewall.Rule) error {
return nil return nil
} }
// createDefaultAllowRules In case if the USP firewall manager can use the native firewall manager we must to create allow rules for // createDefaultAllowRules creates default allow rules for the input and output chains
// input and output chains
func (m *AclManager) createDefaultAllowRules() error { func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{ expIn := []expr.Any{
&expr.Payload{ &expr.Payload{
@@ -214,13 +196,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0x00, 0x00, 0x00, 0x00}, Mask: []byte{0, 0, 0, 0},
Xor: zeroXor, Xor: []byte{0, 0, 0, 0},
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00}, Data: []byte{0, 0, 0, 0},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -246,13 +228,13 @@ func (m *AclManager) createDefaultAllowRules() error {
SourceRegister: 1, SourceRegister: 1,
DestRegister: 1, DestRegister: 1,
Len: 4, Len: 4,
Mask: []byte{0x00, 0x00, 0x00, 0x00}, Mask: []byte{0, 0, 0, 0},
Xor: zeroXor, Xor: []byte{0, 0, 0, 0},
}, },
// net address // net address
&expr.Cmp{ &expr.Cmp{
Register: 1, Register: 1,
Data: []byte{0x00, 0x00, 0x00, 0x00}, Data: []byte{0, 0, 0, 0},
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictAccept, Kind: expr.VerdictAccept,
@@ -266,10 +248,8 @@ func (m *AclManager) createDefaultAllowRules() error {
Exprs: expOut, Exprs: expOut,
}) })
err := m.rConn.Flush() if err := m.rConn.Flush(); err != nil {
if err != nil { return fmt.Errorf(flushError, err)
log.Debugf("failed to create default allow rules: %s", err)
return err
} }
return nil return nil
} }
@@ -290,15 +270,11 @@ func (m *AclManager) Flush() error {
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err) log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
} }
if err := m.refreshRuleHandles(m.chainPrerouting); err != nil {
log.Errorf("failed to refresh rule handles IPv4 prerouting chain: %v", err)
}
return nil return nil
} }
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) { func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
ruleId := generateRuleId(ip, sPort, dPort, direction, action, ipset) ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
return &Rule{ return &Rule{
r.nftRule, r.nftRule,
@@ -308,18 +284,7 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
}, nil }, nil
} }
ifaceKey := expr.MetaKeyIIFNAME var expressions []expr.Any
if direction == firewall.RuleDirectionOUT {
ifaceKey = expr.MetaKeyOIFNAME
}
expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
}
if proto != firewall.ProtocolALL { if proto != firewall.ProtocolALL {
expressions = append(expressions, &expr.Payload{ expressions = append(expressions, &expr.Payload{
@@ -329,21 +294,15 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
Len: uint32(1), Len: uint32(1),
}) })
var protoData []byte protoData, err := protoToInt(proto)
switch proto { if err != nil {
case firewall.ProtocolTCP: return nil, fmt.Errorf("convert protocol to number: %v", err)
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
} }
expressions = append(expressions, &expr.Cmp{ expressions = append(expressions, &expr.Cmp{
Register: 1, Register: 1,
Op: expr.CmpOpEq, Op: expr.CmpOpEq,
Data: protoData, Data: []byte{protoData},
}) })
} }
@@ -432,10 +391,9 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
} else { } else {
chain = m.chainOutputRules chain = m.chainOutputRules
} }
nftRule := m.rConn.InsertRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: chain, Chain: chain,
Position: 0,
Exprs: expressions, Exprs: expressions,
UserData: userData, UserData: userData,
}) })
@@ -453,139 +411,13 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
return rule, nil return rule, nil
} }
func (m *AclManager) addPreroutingFiltering(ipset *nftables.Set, proto firewall.Protocol, port *firewall.Port, ip net.IP) (*Rule, error) {
var protoData []byte
switch proto {
case firewall.ProtocolTCP:
protoData = []byte{unix.IPPROTO_TCP}
case firewall.ProtocolUDP:
protoData = []byte{unix.IPPROTO_UDP}
case firewall.ProtocolICMP:
protoData = []byte{unix.IPPROTO_ICMP}
default:
return nil, fmt.Errorf("unsupported protocol: %s", proto)
}
ruleId := generateRuleIdForMangle(ipset, ip, proto, port)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
r.nftRule,
r.nftSet,
r.ruleID,
ip,
}, nil
}
var ipExpression expr.Any
// add individual IP for match if no ipset defined
rawIP := ip.To4()
if ipset == nil {
ipExpression = &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rawIP,
}
} else {
ipExpression = &expr.Lookup{
SourceRegister: 1,
SetName: ipset.Name,
SetID: ipset.ID,
}
}
expressions := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
ipExpression,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: uint32(9),
Len: uint32(1),
},
&expr.Cmp{
Register: 1,
Op: expr.CmpOpEq,
Data: protoData,
},
}
if port != nil {
expressions = append(expressions,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: encodePort(*port),
},
)
}
expressions = append(expressions,
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
nftRule := m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Position: 0,
Exprs: expressions,
UserData: []byte(ruleId),
})
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf("flush insert rule: %v", err)
}
rule := &Rule{
nftRule: nftRule,
nftSet: ipset,
ruleID: ruleId,
ip: ip,
}
m.rules[ruleId] = rule
if ipset != nil {
m.ipsetStore.AddReferenceToIpset(ipset.Name)
}
return rule, nil
}
func (m *AclManager) createDefaultChains() (err error) { func (m *AclManager) createDefaultChains() (err error) {
// chainNameInputRules // chainNameInputRules
chain := m.createChain(chainNameInputRules) chain := m.createChain(chainNameInputRules)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chain.Name, err) log.Debugf("failed to create chain (%s): %s", chain.Name, err)
return err return fmt.Errorf(flushError, err)
} }
m.chainInputRules = chain m.chainInputRules = chain
@@ -601,9 +433,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-input-filter // netbird-acl-input-filter
// type filter hook input priority filter; policy accept; // type filter hook input priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput) chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
//netbird-acl-input-filter iifname "wt0" ip saddr 100.72.0.0/16 ip daddr != 100.72.0.0/16 accept
m.addRouteAllowRule(chain, expr.MetaKeyIIFNAME)
m.addFwdAllow(chain, expr.MetaKeyIIFNAME)
m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules m.addJumpRule(chain, m.chainInputRules.Name, expr.MetaKeyIIFNAME) // to netbird-acl-input-rules
m.addDropExpressions(chain, expr.MetaKeyIIFNAME) m.addDropExpressions(chain, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
@@ -615,7 +444,6 @@ func (m *AclManager) createDefaultChains() (err error) {
// netbird-acl-output-filter // netbird-acl-output-filter
// type filter hook output priority filter; policy accept; // type filter hook output priority filter; policy accept;
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
m.addRouteAllowRule(chain, expr.MetaKeyOIFNAME)
m.addFwdAllow(chain, expr.MetaKeyOIFNAME) m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
m.addDropExpressions(chain, expr.MetaKeyOIFNAME) m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
@@ -626,29 +454,106 @@ func (m *AclManager) createDefaultChains() (err error) {
} }
// netbird-acl-forward-filter // netbird-acl-forward-filter
m.chainFwFilter = m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
m.addJumpRulesToRtForward() // to m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
m.addMarkAccept() m.addDropExpressions(chainFwFilter, expr.MetaKeyIIFNAME)
m.addJumpRuleToInputChain() // to netbird-acl-input-rules
m.addDropExpressions(m.chainFwFilter, expr.MetaKeyIIFNAME)
err = m.rConn.Flush() err = m.rConn.Flush()
if err != nil { if err != nil {
log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err) log.Debugf("failed to create chain (%s): %s", chainNameForwardFilter, err)
return err return fmt.Errorf(flushError, err)
} }
// netbird-acl-output-filter if err := m.allowRedirectedTraffic(chainFwFilter); err != nil {
// type filter hook output priority filter; policy accept; log.Errorf("failed to allow redirected traffic: %s", err)
m.chainPrerouting = m.createPreroutingMangle()
err = m.rConn.Flush()
if err != nil {
log.Debugf("failed to create chain (%s): %s", m.chainPrerouting.Name, err)
return err
} }
return nil return nil
} }
func (m *AclManager) addJumpRulesToRtForward() { // Makes redirected traffic originally destined for the host itself (now subject to the forward filter)
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
// netbird peer IP.
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
preroutingChain := m.rConn.AddChain(&nftables.Chain{
Name: chainNamePrerouting,
Table: m.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
})
m.addPreroutingRule(preroutingChain)
m.addFwmarkToForward(chainFwFilter)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: preroutingChain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Fib{
Register: 1,
ResultADDRTYPE: true,
FlagDADDR: true,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
},
})
}
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
m.rConn.InsertRule(&nftables.Rule{
Table: m.workTable,
Chain: chainFwFilter,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
},
})
}
func (m *AclManager) addJumpRulesToRtForward(chainFwFilter *nftables.Chain) {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -658,68 +563,15 @@ func (m *AclManager) addJumpRulesToRtForward() {
}, },
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: m.routeingFwChainName, Chain: m.routingFwChainName,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable, Table: m.workTable,
Chain: m.chainFwFilter, Chain: chainFwFilter,
Exprs: expressions, Exprs: expressions,
}) })
expressions = []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.routeingFwChainName,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addMarkAccept() {
// oifname "wt0" meta mark 0x000007e4 accept
// iifname "wt0" meta mark 0x000007e4 accept
ifaces := []expr.MetaKey{expr.MetaKeyIIFNAME, expr.MetaKeyOIFNAME}
for _, iface := range ifaces {
expressions := []expr.Any{
&expr.Meta{Key: iface, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: postroutingMark,
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
} }
func (m *AclManager) createChain(name string) *nftables.Chain { func (m *AclManager) createChain(name string) *nftables.Chain {
@@ -729,10 +581,13 @@ func (m *AclManager) createChain(name string) *nftables.Chain {
} }
chain = m.rConn.AddChain(chain) chain = m.rConn.AddChain(chain)
insertReturnTrafficRule(m.rConn, m.workTable, chain)
return chain return chain
} }
func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.ChainHook) *nftables.Chain { func (m *AclManager) createFilterChainWithHook(name string, hookNum *nftables.ChainHook) *nftables.Chain {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{ chain := &nftables.Chain{
Name: name, Name: name,
@@ -746,74 +601,6 @@ func (m *AclManager) createFilterChainWithHook(name string, hookNum nftables.Cha
return m.rConn.AddChain(chain) return m.rConn.AddChain(chain)
} }
func (m *AclManager) createPreroutingMangle() *nftables.Chain {
polAccept := nftables.ChainPolicyAccept
chain := &nftables.Chain{
Name: "netbird-acl-prerouting-filter",
Table: m.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
Policy: &polAccept,
}
chain = m.rConn.AddChain(chain)
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: m.wgIface.Address().IP.To4(),
},
&expr.Immediate{
Register: 1,
Data: postroutingMark,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: chain,
Exprs: expressions,
})
return chain
}
func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any { func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.MetaKey) []expr.Any {
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
@@ -832,101 +619,9 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
return nil return nil
} }
func (m *AclManager) addJumpRuleToInputChain() {
expressions := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictJump,
Chain: m.chainInputRules.Name,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainFwFilter,
Exprs: expressions,
})
}
func (m *AclManager) addRouteAllowRule(chain *nftables.Chain, netIfName expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp
if netIfName == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
} else {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
}
expressions := []expr.Any{
&expr.Meta{Key: netIfName, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: dstOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
_ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: expressions,
})
}
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
var srcOp, dstOp expr.CmpOp dstOp := expr.CmpOpNeq
if iifname == expr.MetaKeyIIFNAME {
srcOp = expr.CmpOpNeq
dstOp = expr.CmpOpEq
} else {
srcOp = expr.CmpOpEq
dstOp = expr.CmpOpNeq
}
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: iifname, Register: 1}, &expr.Meta{Key: iifname, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -934,24 +629,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: srcOp,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{ &expr.Payload{
DestRegister: 2, DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -982,7 +659,6 @@ func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
} }
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
expressions := []expr.Any{ expressions := []expr.Any{
&expr.Meta{Key: ifaceKey, Register: 1}, &expr.Meta{Key: ifaceKey, Register: 1},
&expr.Cmp{ &expr.Cmp{
@@ -990,47 +666,12 @@ func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr
Register: 1, Register: 1,
Data: ifname(m.wgIface.Name()), Data: ifname(m.wgIface.Name()),
}, },
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Payload{
DestRegister: 2,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Bitwise{
SourceRegister: 2,
DestRegister: 2,
Len: 4,
Xor: []byte{0x0, 0x0, 0x0, 0x0},
Mask: m.wgIface.Address().Network.Mask,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 2,
Data: ip.Unmap().AsSlice(),
},
&expr.Verdict{ &expr.Verdict{
Kind: expr.VerdictJump, Kind: expr.VerdictJump,
Chain: to, Chain: to,
}, },
} }
_ = m.rConn.AddRule(&nftables.Rule{ _ = m.rConn.AddRule(&nftables.Rule{
Table: chain.Table, Table: chain.Table,
Chain: chain, Chain: chain,
@@ -1132,7 +773,7 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
return nil return nil
} }
func generateRuleId( func generatePeerRuleId(
ip net.IP, ip net.IP,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
@@ -1155,33 +796,6 @@ func generateRuleId(
} }
return "set:" + ipset.Name + rulesetID return "set:" + ipset.Name + rulesetID
} }
func generateRuleIdForMangle(ipset *nftables.Set, ip net.IP, proto firewall.Protocol, port *firewall.Port) string {
// case of icmp port is empty
var p string
if port != nil {
p = port.String()
}
if ipset != nil {
return fmt.Sprintf("p:set:%s:%s:%v", ipset.Name, proto, p)
} else {
return fmt.Sprintf("p:ip:%s:%s:%v", ip.String(), proto, p)
}
}
func shouldAddToPrerouting(proto firewall.Protocol, dPort *firewall.Port, direction firewall.RuleDirection) bool {
if proto == "all" {
return false
}
if direction != firewall.RuleDirectionIN {
return false
}
if dPort == nil && proto != firewall.ProtocolICMP {
return false
}
return true
}
func encodePort(port firewall.Port) []byte { func encodePort(port firewall.Port) []byte {
bs := make([]byte, 2) bs := make([]byte, 2)
@@ -1191,6 +805,19 @@ func encodePort(port firewall.Port) []byte {
func ifname(n string) []byte { func ifname(n string) []byte {
b := make([]byte, 16) b := make([]byte, 16)
copy(b, []byte(n+"\x00")) copy(b, n+"\x00")
return b return b
} }
func protoToInt(protocol firewall.Protocol) (uint8, error) {
switch protocol {
case firewall.ProtocolTCP:
return unix.IPPROTO_TCP, nil
case firewall.ProtocolUDP:
return unix.IPPROTO_UDP, nil
case firewall.ProtocolICMP:
return unix.IPPROTO_ICMP, nil
}
return 0, fmt.Errorf("unsupported protocol: %s", protocol)
}

View File

@@ -5,20 +5,34 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const ( const (
// tableName is the name of the table that is used for filtering by the Netbird client // tableNameNetbird is the name of the table that is used for filtering by the Netbird client
tableName = "netbird" tableNameNetbird = "netbird"
tableNameFilter = "filter"
chainNameInput = "INPUT"
) )
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
Address() iface.WGAddress
IsUserspaceBind() bool
}
// Manager of iptables firewall // Manager of iptables firewall
type Manager struct { type Manager struct {
mutex sync.Mutex mutex sync.Mutex
@@ -30,35 +44,75 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(context context.Context, wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
} }
workTable, err := m.createWorkTable() workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create router: %w", err)
} }
m.router, err = newRouter(context, workTable) m.aclManager, err = newAclManager(workTable, wgIface, chainNameRoutingFw)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("create acl manager: %w", err)
}
m.aclManager, err = newAclManager(workTable, wgIface, m.router.RouteingFwChainName())
if err != nil {
return nil, err
} }
return m, nil return m, nil
} }
// AddFiltering rule to the firewall // Init nftables firewall manager
func (m *Manager) Init(stateManager *statemanager.Manager) error {
workTable, err := m.createWorkTable()
if err != nil {
return fmt.Errorf("create work table: %w", err)
}
if err := m.router.init(workTable); err != nil {
return fmt.Errorf("router init: %w", err)
}
if err := m.aclManager.init(workTable); err != nil {
// TODO: cleanup router
return fmt.Errorf("acl manager init: %w", err)
}
stateManager.RegisterState(&ShutdownState{})
// We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
}
// persist early
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}()
return nil
}
// AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddFiltering( func (m *Manager) AddPeerFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -76,33 +130,52 @@ func (m *Manager) AddFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
} }
// DeleteRule from the firewall by rule definition func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) DeleteRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclManager.DeleteRule(rule) if !destination.Addr().Is4() {
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
}
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.aclManager.DeletePeerRule(rule)
}
// DeleteRouteRule deletes a routing rule
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteRouteRule(rule)
} }
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
return true return true
} }
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.AddRoutingRules(pair) return m.router.AddNatRule(pair)
} }
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.router.RemoveRoutingRules(pair) return m.router.RemoveNatRule(pair)
} }
// AllowNetbird allows netbird interface traffic // AllowNetbird allows netbird interface traffic
@@ -126,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain var chain *nftables.Chain
for _, c := range chains { for _, c := range chains {
if c.Table.Name == "filter" && c.Name == "INPUT" { if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c chain = c
break break
} }
@@ -157,47 +230,86 @@ func (m *Manager) AllowNetbird() error {
return nil return nil
} }
// SetLegacyManagement sets the route manager to use legacy management
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
return firewall.SetLegacyManagement(m.router, isLegacy)
}
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
chains, err := m.rConn.ListChains() if err := m.resetNetbirdInputRules(); err != nil {
if err != nil { return fmt.Errorf("reset netbird input rules: %v", err)
return fmt.Errorf("list of chains: %w", err)
} }
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
if err := m.cleanupNetbirdTables(); err != nil {
return fmt.Errorf("cleanup netbird tables: %v", err)
}
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
return fmt.Errorf("delete state: %v", err)
}
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains { for _, c := range chains {
// delete Netbird allow input traffic rule if it exists if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
if c.Table.Name == "filter" && c.Name == "INPUT" {
rules, err := m.rConn.GetRules(c.Table, c) rules, err := m.rConn.GetRules(c.Table, c)
if err != nil { if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err) log.Errorf("get rules for chain %q: %v", c.Name, err)
continue continue
} }
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { m.deleteMatchingRules(rules)
if err := m.rConn.DelRule(r); err != nil { }
log.Errorf("delete rule: %v", err) }
} }
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
} }
} }
} }
}
m.router.ResetForwardRules() func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables() tables, err := m.rConn.ListTables()
if err != nil { if err != nil {
return fmt.Errorf("list of tables: %w", err) return fmt.Errorf("list tables: %w", err)
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableName { if t.Name == tableNameNetbird {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
return nil
return m.rConn.Flush()
} }
// Flush rule/chain/set operations from the buffer // Flush rule/chain/set operations from the buffer
@@ -218,12 +330,12 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableName { if t.Name == tableNameNetbird {
m.rConn.DelTable(t) m.rConn.DelTable(t)
} }
} }
table := m.rConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush() err = m.rConn.Flush()
return table, err return table, err
} }
@@ -251,7 +363,7 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name()) ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules { for _, rule := range existedRules {
if rule.Table.Name == "filter" && rule.Chain.Name == "INPUT" { if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 { if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue continue
@@ -265,3 +377,38 @@ func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftable
} }
return nil return nil
} }
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,
Chain: chain,
Exprs: getEstablishedExprs(1),
}
conn.InsertRule(rule)
}
func getEstablishedExprs(register uint32) []expr.Any {
return []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: register,
},
&expr.Bitwise{
SourceRegister: register,
DestRegister: register,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: register,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
}

View File

@@ -1,7 +1,6 @@
package nftables package nftables
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -9,14 +8,30 @@ import (
"time" "time"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/iface"
) )
var ifaceMock = &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
@@ -40,28 +55,15 @@ func (i *iFaceMock) Address() iface.WGAddress {
func (i *iFaceMock) IsUserspaceBind() bool { return false } func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
mock := &iFaceMock{
NameFunc: func() string {
return "lo"
},
AddressFunc: func() iface.WGAddress {
return iface.WGAddress{
IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"),
Mask: net.IPv4Mask(255, 255, 255, 0),
},
}
},
}
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -70,7 +72,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddFiltering( rule, err := manager.AddPeerFiltering(
ip, ip,
fw.ProtocolTCP, fw.ProtocolTCP,
nil, nil,
@@ -88,17 +90,35 @@ func TestNftablesManager(t *testing.T) {
rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err := testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 1, "expected 1 rules") require.Len(t, rules, 2, "expected 2 rules")
expectedExprs1 := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
ipToAdd, _ := netip.AddrFromSlice(ip) ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap() add := ipToAdd.Unmap()
expectedExprs := []expr.Any{ expectedExprs2 := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname("lo"),
},
&expr.Payload{ &expr.Payload{
DestRegister: 1, DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader, Base: expr.PayloadBaseNetworkHeader,
@@ -134,10 +154,10 @@ func TestNftablesManager(t *testing.T) {
}, },
&expr.Verdict{Kind: expr.VerdictDrop}, &expr.Verdict{Kind: expr.VerdictDrop},
} }
require.ElementsMatch(t, rules[0].Exprs, expectedExprs, "expected the same expressions") require.ElementsMatch(t, rules[1].Exprs, expectedExprs2, "expected the same expressions")
for _, r := range rule { for _, r := range rule {
err = manager.DeleteRule(r) err = manager.DeletePeerRule(r)
require.NoError(t, err, "failed to delete rule") require.NoError(t, err, "failed to delete rule")
} }
@@ -146,9 +166,10 @@ func TestNftablesManager(t *testing.T) {
rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules) rules, err = testClient.GetRules(manager.aclManager.workTable, manager.aclManager.chainInputRules)
require.NoError(t, err, "failed to get rules") require.NoError(t, err, "failed to get rules")
require.Len(t, rules, 0, "expected 0 rules after deletion") // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset() err = manager.Reset(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -171,12 +192,13 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(context.Background(), mock) manager, err := Create(mock)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -187,9 +209,9 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -1,431 +0,0 @@
package nftables
import (
"bytes"
"context"
"errors"
"fmt"
"net"
"net/netip"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/manager"
)
const (
chainNameRouteingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-nat"
userDataAcceptForwardRuleSrc = "frwacceptsrc"
userDataAcceptForwardRuleDst = "frwacceptdst"
loopbackInterface = "lo\x00"
)
// some presets for building nftable rules
var (
zeroXor = binaryutil.NativeEndian.PutUint32(0)
exprCounterAccept = []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
ctx context.Context
stop context.CancelFunc
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
isDefaultFwdRulesEnabled bool
}
func newRouter(parentCtx context.Context, workTable *nftables.Table) (*router, error) {
ctx, cancel := context.WithCancel(parentCtx)
r := &router{
ctx: ctx,
stop: cancel,
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
}
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, err
}
}
err = r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.createContainers()
if err != nil {
log.Errorf("failed to create containers for route: %s", err)
}
return r, err
}
func (r *router) RouteingFwChainName() string {
return chainNameRouteingFw
}
// ResetForwardRules cleans existing nftables default forward rules from the system
func (r *router) ResetForwardRules() {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to reset forward rules: %s", err)
}
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRouteingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRouteingFw,
Table: r.workTable,
})
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: nftables.ChainPriorityNATSource - 1,
Type: nftables.ChainTypeNAT,
})
// Add RETURN rule for loopback interface
loRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte(loopbackInterface),
},
&expr.Verdict{Kind: expr.VerdictReturn},
},
}
r.conn.InsertRule(loRule)
err := r.refreshRulesMap()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRoutingRules appends a nftable rule pair to the forwarding chain and if enabled, to the nat chain
func (r *router) AddRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.addRoutingRule(manager.ForwardingFormat, chainNameRouteingFw, pair, false)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InForwardingFormat, chainNameRouteingFw, manager.GetInPair(pair), false)
if err != nil {
return err
}
if pair.Masquerade {
err = r.addRoutingRule(manager.NatFormat, chainNameRoutingNat, pair, true)
if err != nil {
return err
}
err = r.addRoutingRule(manager.InNatFormat, chainNameRoutingNat, manager.GetInPair(pair), true)
if err != nil {
return err
}
}
if r.filterTable != nil && !r.isDefaultFwdRulesEnabled {
log.Debugf("add default accept forward rule")
r.acceptForwardRule(pair.Source)
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: unable to insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addRoutingRule inserts a nftable rule to the conn client flush queue
func (r *router) addRoutingRule(format, chainName string, pair manager.RouterPair, isNat bool) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
var expression []expr.Any
if isNat {
expression = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) // nolint:gocritic
} else {
expression = append(sourceExp, append(destExp, exprCounterAccept...)...) // nolint:gocritic
}
ruleKey := manager.GenKey(format, pair.ID)
_, exists := r.rules[ruleKey]
if exists {
err := r.removeRoutingRule(format, pair)
if err != nil {
return err
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainName],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
func (r *router) acceptForwardRule(sourceNetwork string) {
src := generateCIDRMatcherExpressions(true, sourceNetwork)
dst := generateCIDRMatcherExpressions(false, "0.0.0.0/0")
var exprs []expr.Any
exprs = append(src, append(dst, &expr.Verdict{ // nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleSrc),
}
r.conn.AddRule(rule)
src = generateCIDRMatcherExpressions(true, "0.0.0.0/0")
dst = generateCIDRMatcherExpressions(false, sourceNetwork)
exprs = append(src, append(dst, &expr.Verdict{ //nolint:gocritic
Kind: expr.VerdictAccept,
})...)
rule = &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: exprs,
UserData: []byte(userDataAcceptForwardRuleDst),
}
r.conn.AddRule(rule)
r.isDefaultFwdRulesEnabled = true
}
// RemoveRoutingRules removes a nftable rule pair from forwarding and nat chains
func (r *router) RemoveRoutingRules(pair manager.RouterPair) error {
err := r.refreshRulesMap()
if err != nil {
return err
}
err = r.removeRoutingRule(manager.ForwardingFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InForwardingFormat, manager.GetInPair(pair))
if err != nil {
return err
}
err = r.removeRoutingRule(manager.NatFormat, pair)
if err != nil {
return err
}
err = r.removeRoutingRule(manager.InNatFormat, manager.GetInPair(pair))
if err != nil {
return err
}
if len(r.rules) == 0 {
err := r.cleanUpDefaultForwardRules()
if err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
}
err = r.conn.Flush()
if err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed rules for %s", pair.Destination)
return nil
}
// removeRoutingRule add a nftable rule to the removal queue and delete from rules map
func (r *router) removeRoutingRule(format string, pair manager.RouterPair) error {
ruleKey := manager.GenKey(format, pair.ID)
rule, found := r.rules[ruleKey]
if found {
ruleType := "forwarding"
if rule.Chain.Type == nftables.ChainTypeNAT {
ruleType = "nat"
}
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("nftables: unable to remove %s rule for %s: %v", ruleType, pair.Destination, err)
}
log.Debugf("nftables: removing %s rule for %s", ruleType, pair.Destination)
delete(r.rules, ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
func (r *router) cleanUpDefaultForwardRules() error {
if r.filterTable == nil {
r.isDefaultFwdRulesEnabled = false
return nil
}
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
var rules []*nftables.Rule
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != "FORWARD" {
continue
}
rules, err = r.conn.GetRules(r.filterTable, chain)
if err != nil {
return err
}
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleSrc)) || bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleDst)) {
err := r.conn.DelRule(rule)
if err != nil {
return err
}
}
}
r.isDefaultFwdRulesEnabled = false
return r.conn.Flush()
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, cidr string) []expr.Any {
ip, network, _ := net.ParseCIDR(cidr)
ipToAdd, _ := netip.AddrFromSlice(ip)
add := ipToAdd.Unmap()
var offSet uint32
if source {
offSet = 12 // src offset
} else {
offSet = 16 // dst offset
}
return []expr.Any{
// fetch src add
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offSet,
Len: 4,
},
// net mask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: network.Mask,
Xor: zeroXor,
},
// net address
&expr.Cmp{
Register: 1,
Data: add.AsSlice(),
},
}
}

View File

@@ -0,0 +1,989 @@
package nftables
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"github.com/coreos/go-iptables/iptables"
"github.com/davecgh/go-spew/spew"
"github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net"
)
const (
chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
)
const refreshRulesMapError = "refresh rules map: %w"
var (
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
)
type router struct {
conn *nftables.Conn
workTable *nftables.Table
filterTable *nftables.Table
chains map[string]*nftables.Chain
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
rules map[string]*nftables.Rule
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper
legacyManagement bool
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
}
r.ipsetCounter = refcounter.New(
r.createIpSet,
r.deleteIpSet,
)
var err error
r.filterTable, err = r.loadFilterTable()
if err != nil {
if errors.Is(err, errFilterTableNotFound) {
log.Warnf("table 'filter' not found for forward rules")
} else {
return nil, fmt.Errorf("load filter table: %w", err)
}
}
return r, nil
}
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.createContainers(); err != nil {
return fmt.Errorf("create containers: %w", err)
}
return nil
}
// Reset cleans existing nftables default forward rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
return r.removeAcceptForwardRules()
}
func (r *router) loadFilterTable() (*nftables.Table, error) {
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
}
for _, table := range tables {
if table.Name == "filter" {
return table, nil
}
}
return nil, errFilterTableNotFound
}
func (r *router) createContainers() error {
r.chains[chainNameRoutingFw] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingFw,
Table: r.workTable,
})
insertReturnTrafficRule(r.conn, r.workTable, r.chains[chainNameRoutingFw])
prio := *nftables.ChainPriorityNATSource - 1
r.chains[chainNameRoutingNat] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingNat,
Table: r.workTable,
Hooknum: nftables.ChainHookPostrouting,
Priority: &prio,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager
// TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting,
Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle,
}
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: unable to initialize table: %v", err)
}
return nil
}
// AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering(
sources []netip.Prefix,
destination netip.Prefix,
proto firewall.Protocol,
sPort *firewall.Port,
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil
}
chain := r.chains[chainNameRoutingFw]
var exprs []expr.Any
switch {
case len(sources) == 1 && sources[0].Bits() == 0:
// If it's 0.0.0.0/0, we don't need to add any source matching
case len(sources) == 1:
// If there's only one source, we can use it directly
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
default:
// If there are multiple sources, create or get an ipset
var err error
exprs, err = r.getIpSetExprs(sources, exprs)
if err != nil {
return nil, fmt.Errorf("get ipset expressions: %w", err)
}
}
// Handle destination
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
// Handle protocol
if proto != firewall.ProtocolALL {
protoNum, err := protoToInt(proto)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
exprs = append(exprs, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1})
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
})
exprs = append(exprs, applyPort(sPort, true)...)
exprs = append(exprs, applyPort(dPort, false)...)
}
exprs = append(exprs, &expr.Counter{})
var verdict expr.VerdictKind
if action == firewall.ActionAccept {
verdict = expr.VerdictAccept
} else {
verdict = expr.VerdictDrop
}
exprs = append(exprs, &expr.Verdict{Kind: verdict})
rule := &nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
}
rule = r.conn.AddRule(rule)
log.Tracef("Adding route rule %s", spew.Sdump(rule))
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
}
r.rules[string(ruleKey)] = rule
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
return ruleKey, nil
}
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
setName := firewall.GenerateSetName(sources)
ref, err := r.ipsetCounter.Increment(setName, sources)
if err != nil {
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
&expr.Lookup{
SourceRegister: 1,
SetName: ref.Out.Name,
SetID: ref.Out.ID,
},
)
return exprs, nil
}
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
ruleKey := rule.GetRuleID()
nftRule, exists := r.rules[ruleKey]
if !exists {
log.Debugf("route rule %s not found", ruleKey)
return nil
}
if nftRule.Handle == 0 {
return fmt.Errorf("route rule %s has no handle", ruleKey)
}
setName := r.findSetNameInRule(nftRule)
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
return fmt.Errorf("delete: %w", err)
}
if setName != "" {
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
return fmt.Errorf("decrement ipset reference: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
// overlapping prefixes will result in an error, so we need to merge them
sources = firewall.MergeIPRanges(sources)
set := &nftables.Set{
Name: setName,
Table: r.workTable,
// required for prefixes
Interval: true,
KeyType: nftables.TypeIPAddr,
}
var elements []nftables.SetElement
for _, prefix := range sources {
// TODO: Implement IPv6 support
if prefix.Addr().Is6() {
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
continue
}
// nftables needs half-open intervals [firstIP, lastIP) for prefixes
// e.g. 10.0.0.0/24 becomes [10.0.0.0, 10.0.1.0), 10.1.1.1/32 becomes [10.1.1.1, 10.1.1.2) etc
firstIP := prefix.Addr()
lastIP := calculateLastIP(prefix).Next()
elements = append(elements,
// the nft tool also adds a line like this, see https://github.com/google/nftables/issues/247
// nftables.SetElement{Key: []byte{0, 0, 0, 0}, IntervalEnd: true},
nftables.SetElement{Key: firstIP.AsSlice()},
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
)
}
if err := r.conn.AddSet(set, elements); err != nil {
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
}
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush error: %w", err)
}
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
return set, nil
}
// calculateLastIP determines the last IP in a given prefix.
func calculateLastIP(prefix netip.Prefix) netip.Addr {
hostMask := ^uint32(0) >> prefix.Masked().Bits()
lastIP := uint32FromNetipAddr(prefix.Addr()) | hostMask
return netip.AddrFrom4(uint32ToBytes(lastIP))
}
// Utility function to convert netip.Addr to uint32.
func uint32FromNetipAddr(addr netip.Addr) uint32 {
b := addr.As4()
return binary.BigEndian.Uint32(b[:])
}
// Utility function to convert uint32 to a netip-compatible byte slice.
func uint32ToBytes(ip uint32) [4]byte {
var b [4]byte
binary.BigEndian.PutUint32(b[:], ip)
return b
}
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
log.Debugf("Deleted unused ipset %s", setName)
return nil
}
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
for _, e := range rule.Exprs {
if lookup, ok := e.(*expr.Lookup); ok {
return lookup.SetName
}
}
return ""
}
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule %s: %w", ruleKey, err)
}
delete(r.rules, ruleKey)
log.Debugf("removed route rule %s", ruleKey)
return nil
}
// AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil {
return fmt.Errorf("add legacy routing rule: %w", err)
}
}
if pair.Masquerade {
if err := r.addNatRule(pair); err != nil {
return fmt.Errorf("add nat rule: %w", err)
}
if err := r.addNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("add inverse nat rule: %w", err)
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
}
return nil
}
// addNatRule inserts a nftables rule to the conn client flush queue
func (r *router) addNatRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
op := expr.CmpOpEq
if pair.Inverse {
op = expr.CmpOpNeq
}
exprs := []expr.Any{
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
// interface matching
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: op,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
}
exprs = append(exprs, sourceExp...)
exprs = append(exprs, destExp...)
var markValue uint32 = nbnet.PreroutingFwmarkMasquerade
if pair.Inverse {
markValue = nbnet.PreroutingFwmarkMasqueradeReturn
}
exprs = append(exprs,
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(markValue),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
SourceRegister: true,
Register: 1,
},
)
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNamePrerouting],
Exprs: exprs,
UserData: []byte(ruleKey),
})
return nil
}
// addPostroutingRules adds the masquerade rules
func (r *router) addPostroutingRules() error {
// First masquerade rule for traffic coming in from WireGuard interface
exprs := []expr.Any{
// Match on the first fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasquerade),
},
// We need to exclude the loopback interface as this changes the ebpf proxy port
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname("lo"),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs,
})
// Second masquerade rule for traffic going out through WireGuard interface
exprs2 := []expr.Any{
// Match on the second fwmark
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkMasqueradeReturn),
},
// Match WireGuard interface
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Counter{},
&expr.Masq{},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: exprs2,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
exprs := []expr.Any{
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
}
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if _, exists := r.rules[ruleKey]; exists {
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingFw],
Exprs: expression,
UserData: []byte(ruleKey),
})
return nil
}
// removeLegacyRouteRule removes a legacy routing rule for mgmt servers pre route acls
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
}
return nil
}
// GetLegacyManagement returns the route manager's legacy management mode
func (r *router) GetLegacyManagement() bool {
return r.legacyManagement
}
// SetLegacyManagement sets the route manager to use legacy management mode
func (r *router) SetLegacyManagement(isLegacy bool) {
r.legacyManagement = isLegacy
}
// RemoveAllLegacyRouteRules removes all legacy routing rules for mgmt servers pre route acls
func (r *router) RemoveAllLegacyRouteRules() error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
for k, rule := range r.rules {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else {
delete(r.rules, k)
}
}
return nberrors.FormatErrorOrNil(merr)
}
// acceptForwardRules adds iif/oif rules in the filter table/forward chain to make sure
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
return nil
}
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
// filter table exists but iptables is not
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
}
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) getAcceptForwardRules() [][]string {
intf := r.wgIface.Name()
return [][]string{
{"-i", intf, "-j", "ACCEPT"},
{"-o", intf, "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"},
}
}
func (r *router) acceptForwardRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptForwardRuleIif),
}
r.conn.InsertRule(iifRule)
oifExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityFilter,
},
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
continue
}
rules, err := r.conn.GetRules(r.filterTable, chain)
if err != nil {
return fmt.Errorf("get rules: %v", err)
}
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
}
}
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf(flushError, err)
}
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
}
}
return nberrors.FormatErrorOrNil(merr)
}
// RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove prerouting rule: %w", err)
}
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
return fmt.Errorf("remove inverse prerouting rule: %w", err)
}
if err := r.removeLegacyRouteRule(pair); err != nil {
return fmt.Errorf("remove legacy routing rule: %w", err)
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
}
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
return nil
}
func (r *router) removeNatRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists {
err := r.conn.DelRule(rule)
if err != nil {
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
}
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
delete(r.rules, ruleKey)
} else {
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
}
return nil
}
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
// duplicates and to get missing attributes that we don't have when adding new rules
func (r *router) refreshRulesMap() error {
for _, chain := range r.chains {
rules, err := r.conn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("nftables: unable to list rules: %v", err)
}
for _, rule := range rules {
if len(rule.UserData) > 0 {
r.rules[string(rule.UserData)] = rule
}
}
}
return nil
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32
if source {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
ones := prefix.Bits()
// 0.0.0.0/0 doesn't need extra expressions
if ones == 0 {
return nil
}
mask := net.CIDRMask(ones, 32)
return []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: offset,
Len: 4,
},
// netmask
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 4,
Mask: mask,
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: prefix.Masked().Addr().AsSlice(),
},
}
}
func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port == nil {
return nil
}
var exprs []expr.Any
offset := uint32(2) // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
exprs = append(exprs, &expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: offset,
Len: 2,
})
if port.IsRange && len(port.Values) == 2 {
// Handle port range
exprs = append(exprs,
&expr.Cmp{
Op: expr.CmpOpGte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
},
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
},
)
} else {
// Handle single port or multiple ports
for i, p := range port.Values {
if i > 0 {
// Add a bitwise OR operation between port checks
exprs = append(exprs, &expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0x00, 0x00, 0xff, 0xff},
Xor: []byte{0x00, 0x00, 0x00, 0x00},
})
}
exprs = append(exprs, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
})
}
}
return exprs
}

View File

@@ -3,12 +3,16 @@
package nftables package nftables
import ( import (
"context" "encoding/binary"
"net/netip"
"os/exec"
"testing" "testing"
"github.com/coreos/go-iptables/iptables" "github.com/coreos/go-iptables/iptables"
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -24,195 +28,627 @@ const (
NFTABLES NFTABLES
) )
func TestNftablesManager_InsertRoutingRules(t *testing.T) { func TestNftablesManager_AddNatRule(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable()
if err != nil {
t.Fatal(err)
}
defer deleteWorkTable()
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := newRouter(context.TODO(), table) // need fw manager to init both acl mgr and router for all chains to be present
require.NoError(t, err, "failed to create router") manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
nftablesTestingClient := &nftables.Conn{} nftablesTestingClient := &nftables.Conn{}
defer manager.ResetForwardRules() rtr := manager.router
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "pair should be inserted")
require.NoError(t, err, "shouldn't return error") t.Cleanup(func() {
require.NoError(t, rtr.RemoveNatRule(testCase.InputPair), "failed to remove rule")
err = manager.AddRoutingRules(testCase.InputPair) })
defer func() {
_ = manager.RemoveRoutingRules(testCase.InputPair)
}()
require.NoError(t, err, "forwarding pair should be inserted")
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
testingExpression := append(sourceExp, destExp...) //nolint:gocritic
fwdRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == fwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade { if testCase.InputPair.Masquerade {
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) // Build expected expressions for connection tracking
conntrackExprs := []expr.Any{
&expr.Ct{
Key: expr.CtKeySTATE,
Register: 1,
},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0, 0, 0, 0},
},
}
// Build interface matching expression
ifaceExprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(ifaceMock.Name()),
},
}
// Build CIDR matching expressions
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
// Combine all expressions in the correct order
// nolint:gocritic
testingExpression := append(conntrackExprs, ifaceExprs...)
testingExpression = append(testingExpression, sourceExp...)
testingExpression = append(testingExpression, destExp...)
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := 0 found := 0
for _, chain := range manager.chains { for _, chain := range rtr.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain) if chain.Name == chainNamePrerouting {
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
for _, rule := range rules { require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey { for _, rule := range rules {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "nat rule elements should match") if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = 1 // Compare expressions up to the mark setting expressions
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "prerouting nat rule elements should match")
found = 1
}
} }
} }
} }
require.Equal(t, 1, found, "should find at least 1 rule to test") require.Equal(t, 1, found, "should find at least 1 rule in prerouting chain")
}
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source)
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination)
testingExpression = append(sourceExp, destExp...) //nolint:gocritic
inFwdRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID)
found = 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inFwdRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income forwarding rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
if testCase.InputPair.Masquerade {
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID)
found := 0
for _, chain := range manager.chains {
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == inNatRuleKey {
require.ElementsMatchf(t, rule.Exprs[:len(testingExpression)], testingExpression, "income nat rule elements should match")
found = 1
}
}
}
require.Equal(t, 1, found, "should find at least 1 rule to test")
} }
}) })
} }
} }
func TestNftablesManager_RemoveRoutingRules(t *testing.T) { func TestNftablesManager_RemoveNatRule(t *testing.T) {
if check() != NFTABLES { if check() != NFTABLES {
t.Skip("nftables not supported on this OS") t.Skip("nftables not supported on this OS")
} }
table, err := createWorkTable() for _, testCase := range test.RemoveRuleTestCases {
if err != nil { t.Run(testCase.Name, func(t *testing.T) {
t.Fatal(err) manager, err := Create(ifaceMock)
t.Cleanup(func() {
require.NoError(t, manager.Reset(nil))
})
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
rtr := manager.router
// First add the NAT rule using the router's method
err = rtr.AddNatRule(testCase.InputPair)
require.NoError(t, err, "should add NAT rule")
// Verify the rule was added
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
found := false
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.True(t, found, "NAT rule should exist before removal")
// Now remove the rule
err = rtr.RemoveNatRule(testCase.InputPair)
require.NoError(t, err, "shouldn't return error when removing rule")
// Verify the rule was removed
found = false
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
require.NoError(t, err, "should list rules after removal")
for _, rule := range rules {
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
found = true
break
}
}
require.False(t, found, "NAT rule should not exist after removal")
// Verify the static postrouting rules still exist
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameRoutingNat])
require.NoError(t, err, "should list postrouting rules")
foundCounter := false
for _, rule := range rules {
for _, e := range rule.Exprs {
if _, ok := e.(*expr.Counter); ok {
foundCounter = true
break
}
}
if foundCounter {
break
}
}
require.True(t, foundCounter, "static postrouting rule should remain")
})
} }
}
func TestRouter_AddRouteFiltering(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
workTable, err := createWorkTable()
require.NoError(t, err, "Failed to create work table")
defer deleteWorkTable() defer deleteWorkTable()
for _, testCase := range test.RemoveRuleTestCases { r, err := newRouter(workTable, ifaceMock)
t.Run(testCase.Name, func(t *testing.T) { require.NoError(t, err, "Failed to create router")
manager, err := newRouter(context.TODO(), table) require.NoError(t, r.init(workTable))
require.NoError(t, err, "failed to create router")
nftablesTestingClient := &nftables.Conn{} defer func(r *router) {
require.NoError(t, r.Reset(), "Failed to reset rules")
}(r)
defer manager.ResetForwardRules() tests := []struct {
name string
sources []netip.Prefix
destination netip.Prefix
proto firewall.Protocol
sPort *firewall.Port
dPort *firewall.Port
direction firewall.RuleDirection
action firewall.Action
expectSet bool
}{
{
name: "Basic TCP rule with single source",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolTCP,
sPort: nil,
dPort: &firewall.Port{Values: []int{80}},
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with multiple sources",
sources: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/16"),
},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolUDP,
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionDrop,
expectSet: true,
},
{
name: "All protocols rule",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
destination: netip.MustParsePrefix("0.0.0.0/0"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "ICMP rule",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/16")},
destination: netip.MustParsePrefix("10.0.0.0/8"),
proto: firewall.ProtocolICMP,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "TCP rule with multiple source ports",
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
destination: netip.MustParsePrefix("192.168.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
dPort: nil,
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "UDP rule with single IP and port range",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
destination: netip.MustParsePrefix("10.0.0.0/24"),
proto: firewall.ProtocolUDP,
sPort: nil,
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
{
name: "TCP rule with source and destination ports",
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
destination: netip.MustParsePrefix("172.16.0.0/16"),
proto: firewall.ProtocolTCP,
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
dPort: &firewall.Port{Values: []int{22}},
direction: firewall.RuleDirectionOUT,
action: firewall.ActionAccept,
expectSet: false,
},
{
name: "Drop all incoming traffic",
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
destination: netip.MustParsePrefix("192.168.0.0/24"),
proto: firewall.ProtocolALL,
sPort: nil,
dPort: nil,
direction: firewall.RuleDirectionIN,
action: firewall.ActionDrop,
expectSet: false,
},
}
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source) for _, tt := range tests {
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination) t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed")
forwardExp := append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic t.Cleanup(func() {
forwardRuleKey := firewall.GenKey(firewall.ForwardingFormat, testCase.InputPair.ID) require.NoError(t, r.DeleteRouteRule(ruleKey), "Failed to delete rule")
insertedForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(forwardRuleKey),
}) })
natExp := append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic // Check if the rule is in the internal map
natRuleKey := firewall.GenKey(firewall.NatFormat, testCase.InputPair.ID) rule, ok := r.rules[ruleKey.GetRuleID()]
assert.True(t, ok, "Rule not found in internal map")
insertedNat := nftablesTestingClient.InsertRule(&nftables.Rule{ t.Log("Internal rule expressions:")
Table: manager.workTable, for i, expr := range rule.Exprs {
Chain: manager.chains[chainNameRoutingNat], t.Logf(" [%d] %T: %+v", i, expr, expr)
Exprs: natExp, }
UserData: []byte(natRuleKey),
})
sourceExp = generateCIDRMatcherExpressions(true, firewall.GetInPair(testCase.InputPair).Source) // Verify internal rule content
destExp = generateCIDRMatcherExpressions(false, firewall.GetInPair(testCase.InputPair).Destination) verifyRule(t, rule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
forwardExp = append(sourceExp, append(destExp, exprCounterAccept...)...) //nolint:gocritic // Check if the rule exists in nftables and verify its content
inForwardRuleKey := firewall.GenKey(firewall.InForwardingFormat, testCase.InputPair.ID) rules, err := r.conn.GetRules(r.workTable, r.chains[chainNameRoutingFw])
insertedInForwarding := nftablesTestingClient.InsertRule(&nftables.Rule{ require.NoError(t, err, "Failed to get rules from nftables")
Table: manager.workTable,
Chain: manager.chains[chainNameRouteingFw],
Exprs: forwardExp,
UserData: []byte(inForwardRuleKey),
})
natExp = append(sourceExp, append(destExp, &expr.Counter{}, &expr.Masq{})...) //nolint:gocritic var nftRule *nftables.Rule
inNatRuleKey := firewall.GenKey(firewall.InNatFormat, testCase.InputPair.ID) for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() {
nftRule = rule
break
}
}
insertedInNat := nftablesTestingClient.InsertRule(&nftables.Rule{ require.NotNil(t, nftRule, "Rule not found in nftables")
Table: manager.workTable, t.Log("Actual nftables rule expressions:")
Chain: manager.chains[chainNameRoutingNat], for i, expr := range nftRule.Exprs {
Exprs: natExp, t.Logf(" [%d] %T: %+v", i, expr, expr)
UserData: []byte(inNatRuleKey), }
})
err = nftablesTestingClient.Flush() // Verify actual nftables rule content
require.NoError(t, err, "shouldn't return error") verifyRule(t, nftRule, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.direction, tt.action, tt.expectSet)
})
}
}
manager.ResetForwardRules() func TestNftablesCreateIpSet(t *testing.T) {
if check() != NFTABLES {
t.Skip("nftables not supported on this system")
}
err = manager.RemoveRoutingRules(testCase.InputPair) workTable, err := createWorkTable()
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "Failed to create work table")
for _, chain := range manager.chains { defer deleteWorkTable()
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name) r, err := newRouter(workTable, ifaceMock)
for _, rule := range rules { require.NoError(t, err, "Failed to create router")
if len(rule.UserData) > 0 { require.NoError(t, r.init(workTable))
require.NotEqual(t, insertedForwarding.UserData, rule.UserData, "forwarding rule should not exist")
require.NotEqual(t, insertedNat.UserData, rule.UserData, "nat rule should not exist") defer func() {
require.NotEqual(t, insertedInForwarding.UserData, rule.UserData, "income forwarding rule should not exist") require.NoError(t, r.Reset(), "Failed to reset router")
require.NotEqual(t, insertedInNat.UserData, rule.UserData, "income nat rule should not exist") }()
tests := []struct {
name string
sources []netip.Prefix
expected []netip.Prefix
}{
{
name: "Single IP",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.1/32")},
},
{
name: "Multiple IPs",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("172.16.0.1/32"),
},
},
{
name: "Single Subnet",
sources: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")},
},
{
name: "Multiple Subnets with Various Prefix Lengths",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("203.0.113.0/26"),
},
},
{
name: "Mix of Single IPs and Subnets in Different Positions",
sources: []netip.Prefix{
netip.MustParsePrefix("192.168.1.1/32"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("172.16.0.1/32"),
netip.MustParsePrefix("203.0.113.0/24"),
},
},
{
name: "Overlapping IPs/Subnets",
sources: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("10.0.0.0/16"),
netip.MustParsePrefix("10.0.0.1/32"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("192.168.1.1/32"),
},
expected: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}
// Add this helper function inside TestNftablesCreateIpSet
printNftSets := func() {
cmd := exec.Command("nft", "list", "sets")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to run 'nft list sets': %v", err)
} else {
t.Logf("Current nft sets:\n%s", output)
}
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
setName := firewall.GenerateSetName(tt.sources)
set, err := r.createIpSet(setName, tt.sources)
if err != nil {
t.Logf("Failed to create IP set: %v", err)
printNftSets()
require.NoError(t, err, "Failed to create IP set")
}
require.NotNil(t, set, "Created set is nil")
// Verify set properties
assert.Equal(t, setName, set.Name, "Set name mismatch")
assert.Equal(t, r.workTable, set.Table, "Set table mismatch")
assert.True(t, set.Interval, "Set interval property should be true")
assert.Equal(t, nftables.TypeIPAddr, set.KeyType, "Set key type mismatch")
// Fetch the created set from nftables
fetchedSet, err := r.conn.GetSetByName(r.workTable, setName)
require.NoError(t, err, "Failed to fetch created set")
require.NotNil(t, fetchedSet, "Fetched set is nil")
// Verify set elements
elements, err := r.conn.GetSetElements(fetchedSet)
require.NoError(t, err, "Failed to get set elements")
// Count the number of unique prefixes (excluding interval end markers)
uniquePrefixes := make(map[string]bool)
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
uniquePrefixes[ip.String()] = true
}
}
// Check against expected merged prefixes
expectedCount := len(tt.expected)
if expectedCount == 0 {
expectedCount = len(tt.sources)
}
assert.Equal(t, expectedCount, len(uniquePrefixes), "Number of unique prefixes in set doesn't match expected")
// Verify each expected prefix is in the set
for _, expected := range tt.expected {
found := false
for _, elem := range elements {
if !elem.IntervalEnd {
ip := netip.AddrFrom4(*(*[4]byte)(elem.Key))
if expected.Contains(ip) {
found = true
break
}
}
}
assert.True(t, found, "Expected prefix %s not found in set", expected)
}
r.conn.DelSet(set)
if err := r.conn.Flush(); err != nil {
t.Logf("Failed to delete set: %v", err)
printNftSets()
}
require.NoError(t, err, "Failed to delete set")
})
}
}
func verifyRule(t *testing.T, rule *nftables.Rule, sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, expectSet bool) {
t.Helper()
assert.NotNil(t, rule, "Rule should not be nil")
// Verify sources and destination
if expectSet {
assert.True(t, containsSetLookup(rule.Exprs), "Rule should contain set lookup for multiple sources")
} else if len(sources) == 1 && sources[0].Bits() != 0 {
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], true), "Rule should contain source CIDR matcher for %s", sources[0])
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, sources[0], false), "Rule should contain destination CIDR matcher for %s", sources[0])
}
}
if direction == firewall.RuleDirectionIN {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, false), "Rule should contain destination CIDR matcher for %s", destination)
} else {
assert.True(t, containsCIDRMatcher(rule.Exprs, destination, true), "Rule should contain source CIDR matcher for %s", destination)
}
// Verify protocol
if proto != firewall.ProtocolALL {
assert.True(t, containsProtocol(rule.Exprs, proto), "Rule should contain protocol matcher for %s", proto)
}
// Verify ports
if sPort != nil {
assert.True(t, containsPort(rule.Exprs, sPort, true), "Rule should contain source port matcher for %v", sPort)
}
if dPort != nil {
assert.True(t, containsPort(rule.Exprs, dPort, false), "Rule should contain destination port matcher for %v", dPort)
}
// Verify action
assert.True(t, containsAction(rule.Exprs, action), "Rule should contain correct action: %s", action)
}
func containsSetLookup(exprs []expr.Any) bool {
for _, e := range exprs {
if _, ok := e.(*expr.Lookup); ok {
return true
}
}
return false
}
func containsCIDRMatcher(exprs []expr.Any, prefix netip.Prefix, isSource bool) bool {
var offset uint32
if isSource {
offset = 12 // src offset
} else {
offset = 16 // dst offset
}
var payloadFound, bitwiseFound, cmpFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseNetworkHeader && ex.Offset == offset && ex.Len == 4 {
payloadFound = true
}
case *expr.Bitwise:
if ex.Len == 4 && len(ex.Mask) == 4 && len(ex.Xor) == 4 {
bitwiseFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 4 {
cmpFound = true
}
}
}
return (payloadFound && bitwiseFound && cmpFound) || prefix.Bits() == 0
}
func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
var offset uint32 = 2 // Default offset for destination port
if isSource {
offset = 0 // Offset for source port
}
var payloadFound, portMatchFound bool
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Payload:
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true
}
case *expr.Cmp:
if port.IsRange {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
portMatchFound = true
}
} else {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values {
if uint16(p) == portValue {
portMatchFound = true
break
}
} }
} }
} }
}) }
if payloadFound && portMatchFound {
return true
}
} }
return false
}
func containsProtocol(exprs []expr.Any, proto firewall.Protocol) bool {
var metaFound, cmpFound bool
expectedProto, _ := protoToInt(proto)
for _, e := range exprs {
switch ex := e.(type) {
case *expr.Meta:
if ex.Key == expr.MetaKeyL4PROTO {
metaFound = true
}
case *expr.Cmp:
if ex.Op == expr.CmpOpEq && len(ex.Data) == 1 && ex.Data[0] == expectedProto {
cmpFound = true
}
}
}
return metaFound && cmpFound
}
func containsAction(exprs []expr.Any, action firewall.Action) bool {
for _, e := range exprs {
if verdict, ok := e.(*expr.Verdict); ok {
switch action {
case firewall.ActionAccept:
return verdict.Kind == expr.VerdictAccept
case firewall.ActionDrop:
return verdict.Kind == expr.VerdictDrop
}
}
}
return false
} }
// check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found. // check returns the firewall type based on common lib checks. It returns UNKNOWN if no firewall is found.
@@ -250,12 +686,12 @@ func createWorkTable() (*nftables.Table, error) {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableName { if t.Name == tableNameNetbird {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }
table := sConn.AddTable(&nftables.Table{Name: tableName, Family: nftables.TableFamilyIPv4}) table := sConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
err = sConn.Flush() err = sConn.Flush()
return table, err return table, err
@@ -273,7 +709,7 @@ func deleteWorkTable() {
} }
for _, t := range tables { for _, t := range tables {
if t.Name == tableName { if t.Name == tableNameNetbird {
sConn.DelTable(t) sConn.DelTable(t)
} }
} }

View File

@@ -0,0 +1 @@
package nftables

View File

@@ -0,0 +1,47 @@
package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
type InterfaceState struct {
NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
}
func (i *InterfaceState) Name() string {
return i.NameStr
}
func (i *InterfaceState) Address() device.WGAddress {
return i.WGAddress
}
func (i *InterfaceState) IsUserspaceBind() bool {
return i.UserspaceBind
}
type ShutdownState struct {
InterfaceState *InterfaceState `json:"interface_state,omitempty"`
}
func (s *ShutdownState) Name() string {
return "nftables_state"
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}
if err := nft.Reset(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err)
}
return nil
}

View File

@@ -1,8 +1,10 @@
//go:build !android
package test package test
import firewall "github.com/netbirdio/netbird/client/firewall/manager" import (
"net/netip"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var ( var (
InsertRuleTestCases = []struct { InsertRuleTestCases = []struct {
@@ -13,8 +15,8 @@ var (
Name: "Insert Forwarding IPV4 Rule", Name: "Insert Forwarding IPV4 Rule",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: "100.100.100.1/32", Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: "100.100.200.0/24", Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: false, Masquerade: false,
}, },
}, },
@@ -22,8 +24,8 @@ var (
Name: "Insert Forwarding And Nat IPV4 Rules", Name: "Insert Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: "100.100.100.1/32", Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: "100.100.200.0/24", Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true, Masquerade: true,
}, },
}, },
@@ -38,8 +40,8 @@ var (
Name: "Remove Forwarding And Nat IPV4 Rules", Name: "Remove Forwarding And Nat IPV4 Rules",
InputPair: firewall.RouterPair{ InputPair: firewall.RouterPair{
ID: "zxa", ID: "zxa",
Source: "100.100.100.1/32", Source: netip.MustParsePrefix("100.100.100.1/32"),
Destination: "100.100.200.0/24", Destination: netip.MustParsePrefix("100.100.200.0/24"),
Masquerade: true, Masquerade: true,
}, },
}, },

View File

@@ -2,8 +2,10 @@
package uspfilter package uspfilter
import "github.com/netbirdio/netbird/client/internal/statemanager"
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -11,7 +13,7 @@ func (m *Manager) Reset() error {
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet)
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset() return m.nativeFirewall.Reset(stateManager)
} }
return nil return nil
} }

View File

@@ -6,6 +6,8 @@ import (
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
type action string type action string
@@ -17,7 +19,7 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset() error { func (m *Manager) Reset(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -11,7 +12,9 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const layerTypeAll = 0
@@ -22,7 +25,7 @@ var (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(iface.PacketFilter) error SetFilter(device.PacketFilter) error
Address() iface.WGAddress Address() iface.WGAddress
} }
@@ -95,6 +98,10 @@ func create(iface IFaceMapper) (*Manager, error) {
return m, nil return m, nil
} }
func (m *Manager) Init(*statemanager.Manager) error {
return nil
}
func (m *Manager) IsServerRouteSupported() bool { func (m *Manager) IsServerRouteSupported() bool {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return false return false
@@ -103,26 +110,26 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
} }
func (m *Manager) InsertRoutingRules(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.InsertRoutingRules(pair) return m.nativeFirewall.AddNatRule(pair)
} }
// RemoveRoutingRules removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveRoutingRules(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeFirewall == nil { if m.nativeFirewall == nil {
return errRouteNotSupported return errRouteNotSupported
} }
return m.nativeFirewall.RemoveRoutingRules(pair) return m.nativeFirewall.RemoveNatRule(pair)
} }
// AddFiltering rule to the firewall // AddPeerFiltering rule to the firewall
// //
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddFiltering( func (m *Manager) AddPeerFiltering(
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -188,8 +195,22 @@ func (m *Manager) AddFiltering(
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
// DeleteRule from the firewall by rule definition func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
func (m *Manager) DeleteRule(rule firewall.Rule) error { if m.nativeFirewall == nil {
return nil, errRouteNotSupported
}
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.DeleteRouteRule(rule)
}
// DeletePeerRule from the firewall by rule definition
func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -215,6 +236,14 @@ func (m *Manager) DeleteRule(rule firewall.Rule) error {
return nil return nil
} }
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
return errRouteNotSupported
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
@@ -395,7 +424,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeleteRule(&rule) return m.DeletePeerRule(&rule)
} }
} }
} }
@@ -403,7 +432,7 @@ func (m *Manager) RemovePacketHook(hookID string) error {
for _, r := range arr { for _, r := range arr {
if r.id == hookID { if r.id == hookID {
rule := r rule := r
return m.DeleteRule(&rule) return m.DeletePeerRule(&rule)
} }
} }
} }

View File

@@ -11,15 +11,16 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
) )
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(iface.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() iface.WGAddress
} }
func (i *IFaceMock) SetFilter(iface iface.PacketFilter) error { func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
if i.SetFilterFunc == nil { if i.SetFilterFunc == nil {
return fmt.Errorf("not implemented") return fmt.Errorf("not implemented")
} }
@@ -35,7 +36,7 @@ func (i *IFaceMock) Address() iface.WGAddress {
func TestManagerCreate(t *testing.T) { func TestManagerCreate(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -49,10 +50,10 @@ func TestManagerCreate(t *testing.T) {
} }
} }
func TestManagerAddFiltering(t *testing.T) { func TestManagerAddPeerFiltering(t *testing.T) {
isSetFilterCalled := false isSetFilterCalled := false
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { SetFilterFunc: func(device.PacketFilter) error {
isSetFilterCalled = true isSetFilterCalled = true
return nil return nil
}, },
@@ -71,7 +72,7 @@ func TestManagerAddFiltering(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -90,7 +91,7 @@ func TestManagerAddFiltering(t *testing.T) {
func TestManagerDeleteRule(t *testing.T) { func TestManagerDeleteRule(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -106,7 +107,7 @@ func TestManagerDeleteRule(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
rule, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -119,14 +120,14 @@ func TestManagerDeleteRule(t *testing.T) {
action = fw.ActionDrop action = fw.ActionDrop
comment = "Test rule 2" comment = "Test rule 2"
rule2, err := m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule { for _, r := range rule {
err = m.DeleteRule(r) err = m.DeletePeerRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -140,7 +141,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
err = m.DeleteRule(r) err = m.DeletePeerRule(r)
if err != nil { if err != nil {
t.Errorf("failed to delete rule: %v", err) t.Errorf("failed to delete rule: %v", err)
return return
@@ -236,7 +237,7 @@ func TestAddUDPPacketHook(t *testing.T) {
func TestManagerReset(t *testing.T) { func TestManagerReset(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -252,13 +253,13 @@ func TestManagerReset(t *testing.T) {
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, port, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
err = m.Reset() err = m.Reset(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -271,7 +272,7 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock) m, err := Create(ifaceMock)
@@ -290,7 +291,7 @@ func TestNotMatchByIP(t *testing.T) {
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule" comment := "Test rule"
_, err = m.AddFiltering(ip, proto, nil, nil, direction, action, "", comment) _, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -329,7 +330,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if err = m.Reset(); err != nil { if err = m.Reset(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -339,7 +340,7 @@ func TestNotMatchByIP(t *testing.T) {
func TestRemovePacketHook(t *testing.T) { func TestRemovePacketHook(t *testing.T) {
// creating mock iface // creating mock iface
iface := &IFaceMock{ iface := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
// creating manager instance // creating manager instance
@@ -388,14 +389,14 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(iface.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(); err != nil { if err := manager.Reset(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -406,9 +407,9 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []int{1000 + i}} port := &fw.Port{Values: []int{1000 + i}}
if i%2 == 0 { if i%2 == 0 {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
} else { } else {
_, err = manager.AddFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
} }
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")

View File

@@ -0,0 +1,5 @@
package bind
import wgConn "golang.zx2c4.com/wireguard/conn"
type Endpoint = wgConn.StdNetEndpoint

View File

@@ -0,0 +1,303 @@
package bind
import (
"fmt"
"net"
"net/netip"
"runtime"
"strings"
"sync"
"github.com/pion/stun/v2"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
type RecvMessage struct {
Endpoint *Endpoint
Buffer []byte
}
type receiverCreator struct {
iceBind *ICEBind
}
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
}
// ICEBind is a bind implementation with two main features:
// 1. filter out STUN messages and handle them
// 2. forward the received packets to the WireGuard interface from the relayed connection
//
// ICEBind.endpoints var is a map that stores the connection for each relayed peer. Fake address is just an IP address
// without port, in the format of 127.1.x.x where x.x is the last two octets of the peer address. We try to avoid to
// use the port because in the Send function the wgConn.Endpoint the port info is not exported.
type ICEBind struct {
*wgConn.StdNetBind
RecvChan chan RecvMessage
transportNet transport.Net
filterFn FilterFn
endpoints map[netip.Addr]net.Conn
endpointsMu sync.Mutex
// every time when Close() is called (i.e. BindUpdate()) we need to close exit from the receiveRelayed and create a
// new closed channel. With the closedChanMu we can safely close the channel and create a new one
closedChan chan struct{}
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
closed bool
muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault
}
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
RecvChan: make(chan RecvMessage, 1),
transportNet: transportNet,
filterFn: filterFn,
endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}),
closed: true,
}
rc := receiverCreator{
ib,
}
ib.StdNetBind = wgConn.NewStdNetBindWithReceiverCreator(rc)
return ib
}
func (s *ICEBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
s.closed = false
s.closedChanMu.Lock()
s.closedChan = make(chan struct{})
s.closedChanMu.Unlock()
fns, port, err := s.StdNetBind.Open(uport)
if err != nil {
return nil, 0, err
}
fns = append(fns, s.receiveRelayed)
return fns, port, nil
}
func (s *ICEBind) Close() error {
if s.closed {
return nil
}
s.closed = true
close(s.closedChan)
return s.StdNetBind.Close()
}
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
if s.udpMux == nil {
return nil, fmt.Errorf("ICEBind has not been initialized yet")
}
return s.udpMux, nil
}
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn
b.endpointsMu.Unlock()
return fakeUDPAddr, nil
}
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock()
defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
}
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
b.endpointsMu.Lock()
conn, ok := b.endpoints[ep.DstIP()]
b.endpointsMu.Unlock()
if !ok {
return b.StdNetBind.Send(bufs, ep)
}
for _, buf := range bufs {
if _, err := conn.Write(buf); err != nil {
return err
}
}
return nil
}
func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc {
s.muUDPMux.Lock()
defer s.muUDPMux.Unlock()
s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{
UDPConn: conn,
Net: s.transportNet,
FilterFn: s.filterFn,
},
)
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
msgs := getMessages(msgsPool)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer putMessages(msgs, msgsPool)
var numMsgs int
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
if rxOffload {
readAt := len(*msgs) - (wgConn.IdealBatchSize / wgConn.UdpSegmentMaxDatagrams)
//nolint
numMsgs, err = pc.ReadBatch((*msgs)[readAt:], 0)
if err != nil {
return 0, err
}
numMsgs, err = wgConn.SplitCoalescedMessages(*msgs, readAt, wgConn.GetGSOSize)
if err != nil {
return 0, err
}
} else {
numMsgs, err = pc.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
}
} else {
msg := &(*msgs)[0]
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
if err != nil {
return 0, err
}
numMsgs = 1
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
// todo: handle err
ok, _ := s.filterOutStunMessages(msg.Buffers, msg.N, msg.Addr)
if ok {
continue
}
sizes[i] = msg.N
if sizes[i] == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
eps[i] = ep
}
return numMsgs, nil
}
}
func (s *ICEBind) filterOutStunMessages(buffers [][]byte, n int, addr net.Addr) (bool, error) {
for i := range buffers {
if !stun.IsMessage(buffers[i]) {
continue
}
msg, err := s.parseSTUNMessage(buffers[i][:n])
if err != nil {
buffers[i] = []byte{}
return true, err
}
muxErr := s.udpMux.HandleSTUNMessage(msg, addr)
if muxErr != nil {
log.Warnf("failed to handle STUN packet")
}
buffers[i] = []byte{}
return true, nil
}
return false, nil
}
func (s *ICEBind) parseSTUNMessage(raw []byte) (*stun.Message, error) {
msg := &stun.Message{
Raw: raw,
}
if err := msg.Decode(); err != nil {
return nil, err
}
return msg, nil
}
// receiveRelayed is a receive function that is used to receive packets from the relayed connection and forward to the
// WireGuard. Critical part is do not block if the Closed() has been called.
func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
c.closedChanMu.RLock()
defer c.closedChanMu.RUnlock()
select {
case <-c.closedChan:
return 0, net.ErrClosed
case msg, ok := <-c.RecvChan:
if !ok {
return 0, net.ErrClosed
}
copy(buffs[0], msg.Buffer)
sizes[0] = len(msg.Buffer)
eps[0] = wgConn.Endpoint(msg.Endpoint)
return 1, nil
}
}
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message)
}
func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
for i := range *msgs {
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
}
msgsPool.Put(msgs)
}

View File

@@ -0,0 +1,5 @@
package configurer
import "errors"
var ErrPeerNotFound = errors.New("peer not found")

View File

@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd //go:build (linux && !android) || freebsd
package iface package configurer
import ( import (
"fmt" "fmt"
@@ -12,18 +12,17 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type wgKernelConfigurer struct { type KernelConfigurer struct {
deviceName string deviceName string
} }
func newWGConfigurer(deviceName string) wgConfigurer { func NewKernelConfigurer(deviceName string) *KernelConfigurer {
wgc := &wgKernelConfigurer{ return &KernelConfigurer{
deviceName: deviceName, deviceName: deviceName,
} }
return wgc
} }
func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) error { func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key") log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey) key, err := wgtypes.ParseKey(privateKey)
if err != nil { if err != nil {
@@ -44,7 +43,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
return nil return nil
} }
func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips // parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps) _, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil { if err != nil {
@@ -75,7 +74,7 @@ func (c *wgKernelConfigurer) updatePeer(peerKey string, allowedIps string, keepA
return nil return nil
} }
func (c *wgKernelConfigurer) removePeer(peerKey string) error { func (c *KernelConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -96,7 +95,7 @@ func (c *wgKernelConfigurer) removePeer(peerKey string) error {
return nil return nil
} }
func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) error { func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP) _, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil { if err != nil {
return err return err
@@ -123,7 +122,7 @@ func (c *wgKernelConfigurer) addAllowedIP(peerKey string, allowedIP string) erro
return nil return nil
} }
func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) error { func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP) _, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil { if err != nil {
return fmt.Errorf("parse allowed IP: %w", err) return fmt.Errorf("parse allowed IP: %w", err)
@@ -165,7 +164,7 @@ func (c *wgKernelConfigurer) removeAllowedIP(peerKey string, allowedIP string) e
return nil return nil
} }
func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) { func (c *KernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer, error) {
wg, err := wgctrl.New() wg, err := wgctrl.New()
if err != nil { if err != nil {
return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err) return wgtypes.Peer{}, fmt.Errorf("wgctl: %w", err)
@@ -189,7 +188,7 @@ func (c *wgKernelConfigurer) getPeer(ifaceName, peerPubKey string) (wgtypes.Peer
return wgtypes.Peer{}, ErrPeerNotFound return wgtypes.Peer{}, ErrPeerNotFound
} }
func (c *wgKernelConfigurer) configure(config wgtypes.Config) error { func (c *KernelConfigurer) configure(config wgtypes.Config) error {
wg, err := wgctrl.New() wg, err := wgctrl.New()
if err != nil { if err != nil {
return err return err
@@ -205,10 +204,10 @@ func (c *wgKernelConfigurer) configure(config wgtypes.Config) error {
return wg.ConfigureDevice(c.deviceName, config) return wg.ConfigureDevice(c.deviceName, config)
} }
func (c *wgKernelConfigurer) close() { func (c *KernelConfigurer) Close() {
} }
func (c *wgKernelConfigurer) getStats(peerKey string) (WGStats, error) { func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
peer, err := c.getPeer(c.deviceName, peerKey) peer, err := c.getPeer(c.deviceName, peerKey)
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err) return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)

View File

@@ -1,6 +1,6 @@
//go:build linux || windows || freebsd //go:build linux || windows || freebsd
package iface package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee // WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "wt0" const WgInterfaceDefault = "wt0"

View File

@@ -1,6 +1,6 @@
//go:build darwin //go:build darwin
package iface package configurer
// WgInterfaceDefault is a default interface name of Wiretrustee // WgInterfaceDefault is a default interface name of Wiretrustee
const WgInterfaceDefault = "utun100" const WgInterfaceDefault = "utun100"

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package iface package configurer
import ( import (
"net" "net"

View File

@@ -1,4 +1,4 @@
package iface package configurer
import ( import (
"net" "net"

View File

@@ -1,4 +1,4 @@
package iface package configurer
import ( import (
"encoding/hex" "encoding/hex"
@@ -19,15 +19,15 @@ import (
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found") var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
type wgUSPConfigurer struct { type WGUSPConfigurer struct {
device *device.Device device *device.Device
deviceName string deviceName string
uapiListener net.Listener uapiListener net.Listener
} }
func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer { func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
wgCfg := &wgUSPConfigurer{ wgCfg := &WGUSPConfigurer{
device: device, device: device,
deviceName: deviceName, deviceName: deviceName,
} }
@@ -35,7 +35,7 @@ func newWGUSPConfigurer(device *device.Device, deviceName string) wgConfigurer {
return wgCfg return wgCfg
} }
func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error { func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
log.Debugf("adding Wireguard private key") log.Debugf("adding Wireguard private key")
key, err := wgtypes.ParseKey(privateKey) key, err := wgtypes.ParseKey(privateKey)
if err != nil { if err != nil {
@@ -52,7 +52,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
// parse allowed ips // parse allowed ips
_, ipNet, err := net.ParseCIDR(allowedIps) _, ipNet, err := net.ParseCIDR(allowedIps)
if err != nil { if err != nil {
@@ -80,7 +80,7 @@ func (c *wgUSPConfigurer) updatePeer(peerKey string, allowedIps string, keepAliv
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *wgUSPConfigurer) removePeer(peerKey string) error { func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
peerKeyParsed, err := wgtypes.ParseKey(peerKey) peerKeyParsed, err := wgtypes.ParseKey(peerKey)
if err != nil { if err != nil {
return err return err
@@ -97,7 +97,7 @@ func (c *wgUSPConfigurer) removePeer(peerKey string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error { func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
_, ipNet, err := net.ParseCIDR(allowedIP) _, ipNet, err := net.ParseCIDR(allowedIP)
if err != nil { if err != nil {
return err return err
@@ -121,7 +121,7 @@ func (c *wgUSPConfigurer) addAllowedIP(peerKey string, allowedIP string) error {
return c.device.IpcSet(toWgUserspaceString(config)) return c.device.IpcSet(toWgUserspaceString(config))
} }
func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error { func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
ipc, err := c.device.IpcGet() ipc, err := c.device.IpcGet()
if err != nil { if err != nil {
return err return err
@@ -185,7 +185,7 @@ func (c *wgUSPConfigurer) removeAllowedIP(peerKey string, ip string) error {
} }
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool // startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
func (t *wgUSPConfigurer) startUAPI() { func (t *WGUSPConfigurer) startUAPI() {
var err error var err error
t.uapiListener, err = openUAPI(t.deviceName) t.uapiListener, err = openUAPI(t.deviceName)
if err != nil { if err != nil {
@@ -207,7 +207,7 @@ func (t *wgUSPConfigurer) startUAPI() {
}(t.uapiListener) }(t.uapiListener)
} }
func (t *wgUSPConfigurer) close() { func (t *WGUSPConfigurer) Close() {
if t.uapiListener != nil { if t.uapiListener != nil {
err := t.uapiListener.Close() err := t.uapiListener.Close()
if err != nil { if err != nil {
@@ -223,7 +223,7 @@ func (t *wgUSPConfigurer) close() {
} }
} }
func (t *wgUSPConfigurer) getStats(peerKey string) (WGStats, error) { func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
ipc, err := t.device.IpcGet() ipc, err := t.device.IpcGet()
if err != nil { if err != nil {
return WGStats{}, fmt.Errorf("ipc get: %w", err) return WGStats{}, fmt.Errorf("ipc get: %w", err)

View File

@@ -1,4 +1,4 @@
package iface package configurer
import ( import (
"encoding/hex" "encoding/hex"

View File

@@ -0,0 +1,9 @@
package configurer
import "time"
type WGStats struct {
LastHandshake time.Time
TxBytes int64
RxBytes int64
}

18
client/iface/device.go Normal file
View File

@@ -0,0 +1,18 @@
//go:build !android
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -1,4 +1,4 @@
package iface package device
// TunAdapter is an interface for create tun device from external service // TunAdapter is an interface for create tun device from external service
type TunAdapter interface { type TunAdapter interface {

View File

@@ -1,18 +1,18 @@
package iface package device
import ( import (
"fmt" "fmt"
"net" "net"
) )
// WGAddress Wireguard parsed address // WGAddress WireGuard parsed address
type WGAddress struct { type WGAddress struct {
IP net.IP IP net.IP
Network *net.IPNet Network *net.IPNet
} }
// parseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func parseWGAddress(address string) (WGAddress, error) { func ParseWGAddress(address string) (WGAddress, error) {
ip, network, err := net.ParseCIDR(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return WGAddress{}, err return WGAddress{}, err

View File

@@ -1,4 +1,4 @@
package iface package device
type MobileIFaceArguments struct { type MobileIFaceArguments struct {
TunAdapter TunAdapter // only for Android TunAdapter TunAdapter // only for Android

View File

@@ -1,22 +1,21 @@
//go:build android //go:build android
// +build android
package iface package device
import ( import (
"strings" "strings"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
) )
// ignore the wgTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type wgTunDevice struct { type WGTunDevice struct {
address WGAddress address WGAddress
port int port int
key string key string
@@ -24,25 +23,25 @@ type wgTunDevice struct {
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
name string name string
device *device.Device device *device.Device
wrapper *DeviceWrapper filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer configurer WGConfigurer
} }
func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice { func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return wgTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
} }
} }
func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string) (wgConfigurer, error) { func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
routesString := routesToString(routes) routesString := routesToString(routes)
@@ -61,24 +60,24 @@ func (t *wgTunDevice) Create(routes []string, dns string, searchDomains []string
return nil, err return nil, err
} }
t.name = name t.name = name
t.wrapper = newDeviceWrapper(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debugf("attaching to interface %v", name) log.Debugf("attaching to interface %v", name)
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()
t.configurer.close() t.configurer.Close()
return nil, err return nil, err
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -93,14 +92,14 @@ func (t *wgTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *wgTunDevice) UpdateAddr(addr WGAddress) error { func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
// todo implement // todo implement
return nil return nil
} }
func (t *wgTunDevice) Close() error { func (t *WGTunDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.close() t.configurer.Close()
} }
if t.device != nil { if t.device != nil {
@@ -115,20 +114,20 @@ func (t *wgTunDevice) Close() error {
return nil return nil
} }
func (t *wgTunDevice) Device() *device.Device { func (t *WGTunDevice) Device() *device.Device {
return t.device return t.device
} }
func (t *wgTunDevice) DeviceName() string { func (t *WGTunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *wgTunDevice) WgAddress() WGAddress { func (t *WGTunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *wgTunDevice) Wrapper() *DeviceWrapper { func (t *WGTunDevice) FilteredDevice() *FilteredDevice {
return t.wrapper return t.filteredDevice
} }
func routesToString(routes []string) string { func routesToString(routes []string) string {

View File

@@ -1,20 +1,20 @@
//go:build !ios //go:build !ios
package iface package device
import ( import (
"fmt" "fmt"
"os/exec" "os/exec"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
) )
type tunDevice struct { type TunDevice struct {
name string name string
address WGAddress address WGAddress
port int port int
@@ -22,33 +22,33 @@ type tunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
wrapper *DeviceWrapper filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer configurer WGConfigurer
} }
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &tunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }
func (t *tunDevice) Create() (wgConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
tunDevice, err := tun.CreateTUN(t.name, t.mtu) tunDevice, err := tun.CreateTUN(t.name, t.mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.wrapper = newDeviceWrapper(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.filteredDevice,
t.iceBind, t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
@@ -59,17 +59,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err) return nil, fmt.Errorf("error assigning ip: %s", err)
} }
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()
t.configurer.close() t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err) return nil, fmt.Errorf("error configuring interface: %s", err)
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -84,14 +84,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *tunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
func (t *tunDevice) Close() error { func (t *TunDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.close() t.configurer.Close()
} }
if t.device != nil { if t.device != nil {
@@ -105,20 +105,20 @@ func (t *tunDevice) Close() error {
return nil return nil
} }
func (t *tunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *tunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *tunDevice) Wrapper() *DeviceWrapper { func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.wrapper return t.filteredDevice
} }
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *tunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String())
if out, err := cmd.CombinedOutput(); err != nil { if out, err := cmd.CombinedOutput(); err != nil {
log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out) log.Errorf("adding address command '%v' failed with output: %s", cmd.String(), out)

View File

@@ -1,4 +1,4 @@
package iface package device
import ( import (
"net" "net"
@@ -28,22 +28,23 @@ type PacketFilter interface {
SetNetwork(*net.IPNet) SetNetwork(*net.IPNet)
} }
// DeviceWrapper to override Read or Write of packets // FilteredDevice to override Read or Write of packets
type DeviceWrapper struct { type FilteredDevice struct {
tun.Device tun.Device
filter PacketFilter filter PacketFilter
mutex sync.RWMutex mutex sync.RWMutex
} }
// newDeviceWrapper constructor function // newDeviceFilter constructor function
func newDeviceWrapper(device tun.Device) *DeviceWrapper { func newDeviceFilter(device tun.Device) *FilteredDevice {
return &DeviceWrapper{ return &FilteredDevice{
Device: device, Device: device,
} }
} }
// Read wraps read method with filtering feature // Read wraps read method with filtering feature
func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil { if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err return 0, err
} }
@@ -68,7 +69,7 @@ func (d *DeviceWrapper) Read(bufs [][]byte, sizes []int, offset int) (n int, err
} }
// Write wraps write method with filtering feature // Write wraps write method with filtering feature
func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) { func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock() d.mutex.RLock()
filter := d.filter filter := d.filter
d.mutex.RUnlock() d.mutex.RUnlock()
@@ -92,7 +93,7 @@ func (d *DeviceWrapper) Write(bufs [][]byte, offset int) (int, error) {
} }
// SetFilter sets packet filter to device // SetFilter sets packet filter to device
func (d *DeviceWrapper) SetFilter(filter PacketFilter) { func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Lock() d.mutex.Lock()
d.filter = filter d.filter = filter
d.mutex.Unlock() d.mutex.Unlock()

View File

@@ -1,4 +1,4 @@
package iface package device
import ( import (
"net" "net"
@@ -7,7 +7,8 @@ import (
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
mocks "github.com/netbirdio/netbird/iface/mocks"
mocks "github.com/netbirdio/netbird/client/iface/mocks"
) )
func TestDeviceWrapperRead(t *testing.T) { func TestDeviceWrapperRead(t *testing.T) {
@@ -51,7 +52,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
wrapped := newDeviceWrapper(tun) wrapped := newDeviceFilter(tun)
bufs := [][]byte{{}} bufs := [][]byte{{}}
sizes := []int{0} sizes := []int{0}
@@ -99,7 +100,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun := mocks.NewMockDevice(ctrl) tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(mockBufs, 0).Return(1, nil) tun.EXPECT().Write(mockBufs, 0).Return(1, nil)
wrapped := newDeviceWrapper(tun) wrapped := newDeviceFilter(tun)
bufs := [][]byte{buffer.Bytes()} bufs := [][]byte{buffer.Bytes()}
@@ -147,7 +148,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
bufs := [][]byte{buffer.Bytes()} bufs := [][]byte{buffer.Bytes()}
@@ -202,7 +203,7 @@ func TestDeviceWrapperRead(t *testing.T) {
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
wrapped := newDeviceWrapper(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
bufs := [][]byte{{}} bufs := [][]byte{{}}

View File

@@ -1,21 +1,21 @@
//go:build ios //go:build ios
// +build ios // +build ios
package iface package device
import ( import (
"os" "os"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
) )
type tunDevice struct { type TunDevice struct {
name string name string
address WGAddress address WGAddress
port int port int
@@ -23,24 +23,24 @@ type tunDevice struct {
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunFd int tunFd int
device *device.Device device *device.Device
wrapper *DeviceWrapper filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer configurer WGConfigurer
} }
func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &tunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
tunFd: tunFd, tunFd: tunFd,
} }
} }
func (t *tunDevice) Create() (wgConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
log.Infof("create tun interface") log.Infof("create tun interface")
dupTunFd, err := unix.Dup(t.tunFd) dupTunFd, err := unix.Dup(t.tunFd)
@@ -62,24 +62,24 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, err return nil, err
} }
t.wrapper = newDeviceWrapper(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
log.Debug("Attaching to interface") log.Debug("Attaching to interface")
t.device = device.NewDevice(t.wrapper, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] ")) t.device = device.NewDevice(t.filteredDevice, t.iceBind, device.NewLogger(wgLogLevel(), "[wiretrustee] "))
// without this property mobile devices can discover remote endpoints if the configured one was wrong. // without this property mobile devices can discover remote endpoints if the configured one was wrong.
// this helps with support for the older NetBird clients that had a hardcoded direct mode // this helps with support for the older NetBird clients that had a hardcoded direct mode
// t.device.DisableSomeRoamingForBrokenMobileSemantics() // t.device.DisableSomeRoamingForBrokenMobileSemantics()
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()
t.configurer.close() t.configurer.Close()
return nil, err return nil, err
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -94,17 +94,17 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *tunDevice) Device() *device.Device { func (t *TunDevice) Device() *device.Device {
return t.device return t.device
} }
func (t *tunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *tunDevice) Close() error { func (t *TunDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.close() t.configurer.Close()
} }
if t.device != nil { if t.device != nil {
@@ -119,15 +119,15 @@ func (t *tunDevice) Close() error {
return nil return nil
} }
func (t *tunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *tunDevice) UpdateAddr(addr WGAddress) error { func (t *TunDevice) UpdateAddr(addr WGAddress) error {
// todo implement // todo implement
return nil return nil
} }
func (t *tunDevice) Wrapper() *DeviceWrapper { func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.wrapper return t.filteredDevice
} }

View File

@@ -1,6 +1,6 @@
//go:build (linux && !android) || freebsd //go:build (linux && !android) || freebsd
package iface package device
import ( import (
"context" "context"
@@ -10,11 +10,12 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
) )
type tunKernelDevice struct { type TunKernelDevice struct {
name string name string
address WGAddress address WGAddress
wgPort int wgPort int
@@ -31,11 +32,11 @@ type tunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
checkUser() checkUser()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &tunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
ctxCancel: cancel, ctxCancel: cancel,
name: name, name: name,
@@ -47,7 +48,7 @@ func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu in
} }
} }
func (t *tunKernelDevice) Create() (wgConfigurer, error) { func (t *TunKernelDevice) Create() (WGConfigurer, error) {
link := newWGLink(t.name) link := newWGLink(t.name)
if err := link.recreate(); err != nil { if err := link.recreate(); err != nil {
@@ -67,16 +68,16 @@ func (t *tunKernelDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("set mtu: %w", err) return nil, fmt.Errorf("set mtu: %w", err)
} }
configurer := newWGConfigurer(t.name) configurer := configurer.NewKernelConfigurer(t.name)
if err := configurer.configureInterface(t.key, t.wgPort); err != nil { if err := configurer.ConfigureInterface(t.key, t.wgPort); err != nil {
return nil, fmt.Errorf("error configuring interface: %s", err) return nil, fmt.Errorf("error configuring interface: %s", err)
} }
return configurer, nil return configurer, nil
} }
func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.udpMux != nil { if t.udpMux != nil {
return t.udpMux, nil return t.udpMux, nil
} }
@@ -111,12 +112,12 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil return t.udpMux, nil
} }
func (t *tunKernelDevice) UpdateAddr(address WGAddress) error { func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
func (t *tunKernelDevice) Close() error { func (t *TunKernelDevice) Close() error {
if t.link == nil { if t.link == nil {
return nil return nil
} }
@@ -144,19 +145,19 @@ func (t *tunKernelDevice) Close() error {
return closErr return closErr
} }
func (t *tunKernelDevice) WgAddress() WGAddress { func (t *TunKernelDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *tunKernelDevice) DeviceName() string { func (t *TunKernelDevice) DeviceName() string {
return t.name return t.name
} }
func (t *tunKernelDevice) Wrapper() *DeviceWrapper { func (t *TunKernelDevice) FilteredDevice() *FilteredDevice {
return nil return nil
} }
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *tunKernelDevice) assignAddr() error { func (t *TunKernelDevice) assignAddr() error {
return t.link.assignAddr(t.address) return t.link.assignAddr(t.address)
} }

View File

@@ -1,20 +1,20 @@
//go:build !android //go:build !android
// +build !android // +build !android
package iface package device
import ( import (
"fmt" "fmt"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/iface/netstack" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/netstack"
) )
type tunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address WGAddress address WGAddress
port int port int
@@ -23,42 +23,42 @@ type tunNetstackDevice struct {
listenAddress string listenAddress string
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
wrapper *DeviceWrapper filteredDevice *FilteredDevice
nsTun *netstack.NetStackTun nsTun *netstack.NetStackTun
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer configurer WGConfigurer
} }
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice { func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &tunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
port: wgPort, port: wgPort,
key: key, key: key,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }
func (t *tunNetstackDevice) Create() (wgConfigurer, error) { func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
log.Info("create netstack tun interface") log.Info("create netstack tun interface")
t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu) t.nsTun = netstack.NewNetStackTun(t.listenAddress, t.address.IP.String(), t.mtu)
tunIface, err := t.nsTun.Create() tunIface, err := t.nsTun.Create()
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.wrapper = newDeviceWrapper(tunIface) t.filteredDevice = newDeviceFilter(tunIface)
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.filteredDevice,
t.iceBind, t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
_ = tunIface.Close() _ = tunIface.Close()
return nil, fmt.Errorf("error configuring interface: %s", err) return nil, fmt.Errorf("error configuring interface: %s", err)
@@ -68,7 +68,7 @@ func (t *tunNetstackDevice) Create() (wgConfigurer, error) {
return t.configurer, nil return t.configurer, nil
} }
func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }
@@ -87,13 +87,13 @@ func (t *tunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *tunNetstackDevice) UpdateAddr(WGAddress) error { func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
return nil return nil
} }
func (t *tunNetstackDevice) Close() error { func (t *TunNetstackDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.close() t.configurer.Close()
} }
if t.device != nil { if t.device != nil {
@@ -106,14 +106,14 @@ func (t *tunNetstackDevice) Close() error {
return nil return nil
} }
func (t *tunNetstackDevice) WgAddress() WGAddress { func (t *TunNetstackDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *tunNetstackDevice) DeviceName() string { func (t *TunNetstackDevice) DeviceName() string {
return t.name return t.name
} }
func (t *tunNetstackDevice) Wrapper() *DeviceWrapper { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice {
return t.wrapper return t.filteredDevice
} }

View File

@@ -1,21 +1,21 @@
//go:build (linux && !android) || freebsd //go:build (linux && !android) || freebsd
package iface package device
import ( import (
"fmt" "fmt"
"os" "os"
"runtime" "runtime"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
) )
type tunUSPDevice struct { type USPDevice struct {
name string name string
address WGAddress address WGAddress
port int port int
@@ -23,39 +23,39 @@ type tunUSPDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
device *device.Device device *device.Device
wrapper *DeviceWrapper filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer configurer WGConfigurer
} }
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
checkUser() checkUser()
return &tunUSPDevice{ return &USPDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }
func (t *tunUSPDevice) Create() (wgConfigurer, error) { func (t *USPDevice) Create() (WGConfigurer, error) {
log.Info("create tun interface") log.Info("create tun interface")
tunIface, err := tun.CreateTUN(t.name, t.mtu) tunIface, err := tun.CreateTUN(t.name, t.mtu)
if err != nil { if err != nil {
log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err) log.Debugf("failed to create tun interface (%s, %d): %s", t.name, t.mtu, err)
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.wrapper = newDeviceWrapper(tunIface) t.filteredDevice = newDeviceFilter(tunIface)
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.filteredDevice,
t.iceBind, t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
@@ -66,17 +66,17 @@ func (t *tunUSPDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err) return nil, fmt.Errorf("error assigning ip: %s", err)
} }
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()
t.configurer.close() t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err) return nil, fmt.Errorf("error configuring interface: %s", err)
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
if t.device == nil { if t.device == nil {
return nil, fmt.Errorf("device is not ready yet") return nil, fmt.Errorf("device is not ready yet")
} }
@@ -96,14 +96,14 @@ func (t *tunUSPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *tunUSPDevice) UpdateAddr(address WGAddress) error { func (t *USPDevice) UpdateAddr(address WGAddress) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
func (t *tunUSPDevice) Close() error { func (t *USPDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.close() t.configurer.Close()
} }
if t.device != nil { if t.device != nil {
@@ -116,20 +116,20 @@ func (t *tunUSPDevice) Close() error {
return nil return nil
} }
func (t *tunUSPDevice) WgAddress() WGAddress { func (t *USPDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *tunUSPDevice) DeviceName() string { func (t *USPDevice) DeviceName() string {
return t.name return t.name
} }
func (t *tunUSPDevice) Wrapper() *DeviceWrapper { func (t *USPDevice) FilteredDevice() *FilteredDevice {
return t.wrapper return t.filteredDevice
} }
// assignAddr Adds IP address to the tunnel interface // assignAddr Adds IP address to the tunnel interface
func (t *tunUSPDevice) assignAddr() error { func (t *USPDevice) assignAddr() error {
link := newWGLink(t.name) link := newWGLink(t.name)
return link.assignAddr(t.address) return link.assignAddr(t.address)

View File

@@ -1,22 +1,22 @@
package iface package device
import ( import (
"fmt" "fmt"
"net/netip" "net/netip"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
) )
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type tunDevice struct { type TunDevice struct {
name string name string
address WGAddress address WGAddress
port int port int
@@ -26,19 +26,19 @@ type tunDevice struct {
device *device.Device device *device.Device
nativeTunDevice *tun.NativeTun nativeTunDevice *tun.NativeTun
wrapper *DeviceWrapper filteredDevice *FilteredDevice
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
configurer wgConfigurer configurer WGConfigurer
} }
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice { func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &tunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet, filterFn), iceBind: iceBind,
} }
} }
@@ -50,7 +50,7 @@ func getGUID() (windows.GUID, error) {
return windows.GUIDFromString(guidString) return windows.GUIDFromString(guidString)
} }
func (t *tunDevice) Create() (wgConfigurer, error) { func (t *TunDevice) Create() (WGConfigurer, error) {
guid, err := getGUID() guid, err := getGUID()
if err != nil { if err != nil {
log.Errorf("failed to get GUID: %s", err) log.Errorf("failed to get GUID: %s", err)
@@ -62,11 +62,11 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error creating tun device: %s", err) return nil, fmt.Errorf("error creating tun device: %s", err)
} }
t.nativeTunDevice = tunDevice.(*tun.NativeTun) t.nativeTunDevice = tunDevice.(*tun.NativeTun)
t.wrapper = newDeviceWrapper(tunDevice) t.filteredDevice = newDeviceFilter(tunDevice)
// We need to create a wireguard-go device and listen to configuration requests // We need to create a wireguard-go device and listen to configuration requests
t.device = device.NewDevice( t.device = device.NewDevice(
t.wrapper, t.filteredDevice,
t.iceBind, t.iceBind,
device.NewLogger(wgLogLevel(), "[netbird] "), device.NewLogger(wgLogLevel(), "[netbird] "),
) )
@@ -92,17 +92,17 @@ func (t *tunDevice) Create() (wgConfigurer, error) {
return nil, fmt.Errorf("error assigning ip: %s", err) return nil, fmt.Errorf("error assigning ip: %s", err)
} }
t.configurer = newWGUSPConfigurer(t.device, t.name) t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
err = t.configurer.configureInterface(t.key, t.port) err = t.configurer.ConfigureInterface(t.key, t.port)
if err != nil { if err != nil {
t.device.Close() t.device.Close()
t.configurer.close() t.configurer.Close()
return nil, fmt.Errorf("error configuring interface: %s", err) return nil, fmt.Errorf("error configuring interface: %s", err)
} }
return t.configurer, nil return t.configurer, nil
} }
func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) { func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
err := t.device.Up() err := t.device.Up()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -117,14 +117,14 @@ func (t *tunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *tunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address WGAddress) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
func (t *tunDevice) Close() error { func (t *TunDevice) Close() error {
if t.configurer != nil { if t.configurer != nil {
t.configurer.close() t.configurer.Close()
} }
if t.device != nil { if t.device != nil {
@@ -138,19 +138,19 @@ func (t *tunDevice) Close() error {
} }
return nil return nil
} }
func (t *tunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() WGAddress {
return t.address return t.address
} }
func (t *tunDevice) DeviceName() string { func (t *TunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *tunDevice) Wrapper() *DeviceWrapper { func (t *TunDevice) FilteredDevice() *FilteredDevice {
return t.wrapper return t.filteredDevice
} }
func (t *tunDevice) getInterfaceGUIDString() (string, error) { func (t *TunDevice) GetInterfaceGUIDString() (string, error) {
if t.nativeTunDevice == nil { if t.nativeTunDevice == nil {
return "", fmt.Errorf("interface has not been initialized yet") return "", fmt.Errorf("interface has not been initialized yet")
} }
@@ -164,7 +164,7 @@ func (t *tunDevice) getInterfaceGUIDString() (string, error) {
} }
// assignAddr Adds IP address to the tunnel interface and network route based on the range provided // assignAddr Adds IP address to the tunnel interface and network route based on the range provided
func (t *tunDevice) assignAddr() error { func (t *TunDevice) assignAddr() error {
luid := winipcfg.LUID(t.nativeTunDevice.LUID()) luid := winipcfg.LUID(t.nativeTunDevice.LUID())
log.Debugf("adding address %s to interface: %s", t.address.IP, t.name) log.Debugf("adding address %s to interface: %s", t.address.IP, t.name)
return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())}) return luid.SetIPAddresses([]netip.Prefix{netip.MustParsePrefix(t.address.String())})

View File

@@ -0,0 +1,20 @@
package device
import (
"net"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer"
)
type WGConfigurer interface {
ConfigureInterface(privateKey string, port int) error
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(peerKey string) error
AddAllowedIP(peerKey string, allowedIP string) error
RemoveAllowedIP(peerKey string, allowedIP string) error
Close()
GetStats(peerKey string) (configurer.WGStats, error)
}

View File

@@ -1,6 +1,6 @@
//go:build (!linux && !freebsd) || android //go:build (!linux && !freebsd) || android
package iface package device
// WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only) // WireGuardModuleIsLoaded check if we can load WireGuard mod (linux only)
func WireGuardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {

View File

@@ -1,4 +1,4 @@
package iface package device
// WireGuardModuleIsLoaded check if kernel support wireguard // WireGuardModuleIsLoaded check if kernel support wireguard
func WireGuardModuleIsLoaded() bool { func WireGuardModuleIsLoaded() bool {
@@ -10,8 +10,8 @@ func WireGuardModuleIsLoaded() bool {
return false return false
} }
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it // ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool { func ModuleTunIsLoaded() bool {
// Assume tun supported by freebsd kernel by default // Assume tun supported by freebsd kernel by default
// TODO: implement check for module loaded in kernel or build-it // TODO: implement check for module loaded in kernel or build-it
return true return true

View File

@@ -1,7 +1,7 @@
//go:build linux && !android //go:build linux && !android
// Package iface provides wireguard network interface creation and management // Package iface provides wireguard network interface creation and management
package iface package device
import ( import (
"bufio" "bufio"
@@ -66,8 +66,8 @@ func getModuleRoot() string {
return filepath.Join(moduleLibDir, string(uname.Release[:i])) return filepath.Join(moduleLibDir, string(uname.Release[:i]))
} }
// tunModuleIsLoaded check if tun module exist, if is not attempt to load it // ModuleTunIsLoaded check if tun module exist, if is not attempt to load it
func tunModuleIsLoaded() bool { func ModuleTunIsLoaded() bool {
_, err := os.Stat("/dev/net/tun") _, err := os.Stat("/dev/net/tun")
if err == nil { if err == nil {
return true return true

View File

@@ -1,4 +1,6 @@
package iface //go:build linux && !android
package device
import ( import (
"bufio" "bufio"
@@ -132,7 +134,7 @@ func resetGlobals() {
} }
func createFiles(t *testing.T) (string, []module) { func createFiles(t *testing.T) (string, []module) {
t.Helper() t.Helper()
writeFile := func(path, text string) { writeFile := func(path, text string) {
if err := os.WriteFile(path, []byte(text), 0644); err != nil { if err := os.WriteFile(path, []byte(text), 0644); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -168,7 +170,7 @@ func createFiles(t *testing.T) (string, []module) {
} }
func getRandomLoadedModule(t *testing.T) (string, error) { func getRandomLoadedModule(t *testing.T) (string, error) {
t.Helper() t.Helper()
f, err := os.Open("/proc/modules") f, err := os.Open("/proc/modules")
if err != nil { if err != nil {
return "", err return "", err

View File

@@ -1,10 +1,11 @@
package iface package device
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/iface/freebsd"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd"
) )
type wgLink struct { type wgLink struct {

View File

@@ -1,6 +1,6 @@
//go:build linux && !android //go:build linux && !android
package iface package device
import ( import (
"fmt" "fmt"

View File

@@ -1,4 +1,4 @@
package iface package device
import ( import (
"os" "os"

View File

@@ -0,0 +1,4 @@
package device
// CustomWindowsGUIDString is a custom GUID string for the interface
var CustomWindowsGUIDString string

View File

@@ -0,0 +1,16 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
)
type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error
WgAddress() WGAddress
DeviceName() string
Close() error
FilteredDevice() *device.FilteredDevice
}

View File

@@ -6,31 +6,55 @@ import (
"sync" "sync"
"time" "time"
"github.com/hashicorp/go-multierror"
"github.com/pion/transport/v3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
) )
const ( const (
DefaultMTU = 1280 DefaultMTU = 1280
DefaultWgPort = 51820 DefaultWgPort = 51820
WgInterfaceDefault = configurer.WgInterfaceDefault
) )
// WGIface represents a interface instance type WGAddress = device.WGAddress
type wgProxyFactory interface {
GetProxy() wgproxy.Proxy
Free() error
}
type WGIFaceOpts struct {
IFaceName string
Address string
WGPort int
WGPrivKey string
MTU int
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn bind.FilterFn
}
// WGIface represents an interface instance
type WGIface struct { type WGIface struct {
tun wgTunDevice tun WGTunDevice
userspaceBind bool userspaceBind bool
mu sync.Mutex mu sync.Mutex
configurer wgConfigurer configurer device.WGConfigurer
filter PacketFilter filter device.PacketFilter
wgProxyFactory wgProxyFactory
} }
type WGStats struct { func (w *WGIface) GetProxy() wgproxy.Proxy {
LastHandshake time.Time return w.wgProxyFactory.GetProxy()
TxBytes int64
RxBytes int64
} }
// IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind // IsUserspaceBind indicates whether this interfaces is userspace with bind.ICEBind
@@ -44,7 +68,7 @@ func (w *WGIface) Name() string {
} }
// Address returns the interface address // Address returns the interface address
func (w *WGIface) Address() WGAddress { func (w *WGIface) Address() device.WGAddress {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
@@ -75,7 +99,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
addr, err := parseWGAddress(newAddr) addr, err := device.ParseWGAddress(newAddr)
if err != nil { if err != nil {
return err return err
} }
@@ -90,7 +114,7 @@ func (w *WGIface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.D
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint) log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
return w.configurer.updatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
} }
// RemovePeer removes a Wireguard Peer from the interface iface // RemovePeer removes a Wireguard Peer from the interface iface
@@ -99,7 +123,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName()) log.Debugf("Removing peer %s from interface %s ", peerKey, w.tun.DeviceName())
return w.configurer.removePeer(peerKey) return w.configurer.RemovePeer(peerKey)
} }
// AddAllowedIP adds a prefix to the allowed IPs list of peer // AddAllowedIP adds a prefix to the allowed IPs list of peer
@@ -108,7 +132,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Adding allowed IP to interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.addAllowedIP(peerKey, allowedIP) return w.configurer.AddAllowedIP(peerKey, allowedIP)
} }
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer // RemoveAllowedIP removes a prefix from the allowed IPs list of peer
@@ -117,7 +141,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
defer w.mu.Unlock() defer w.mu.Unlock()
log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP) log.Debugf("Removing allowed IP from interface %s and peer %s: allowed IP %s ", w.tun.DeviceName(), peerKey, allowedIP)
return w.configurer.removeAllowedIP(peerKey, allowedIP) return w.configurer.RemoveAllowedIP(peerKey, allowedIP)
} }
// Close closes the tunnel interface // Close closes the tunnel interface
@@ -125,42 +149,46 @@ func (w *WGIface) Close() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
err := w.tun.Close() var result *multierror.Error
if err != nil {
return fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err) if err := w.wgProxyFactory.Free(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }
err = w.waitUntilRemoved() if err := w.tun.Close(); err != nil {
if err != nil { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}
if err := w.waitUntilRemoved(); err != nil {
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
err = w.Destroy() if err := w.Destroy(); err != nil {
if err != nil { result = multierror.Append(result, fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err))
return fmt.Errorf("failed to remove WireGuard interface %s: %w", w.Name(), err) return errors.FormatErrorOrNil(result)
} }
log.Infof("interface %s successfully removed", w.Name()) log.Infof("interface %s successfully removed", w.Name())
} }
return nil return errors.FormatErrorOrNil(result)
} }
// SetFilter sets packet filters for the userspace implementation // SetFilter sets packet filters for the userspace implementation
func (w *WGIface) SetFilter(filter PacketFilter) error { func (w *WGIface) SetFilter(filter device.PacketFilter) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
if w.tun.Wrapper() == nil { if w.tun.FilteredDevice() == nil {
return fmt.Errorf("userspace packet filtering not handled on this device") return fmt.Errorf("userspace packet filtering not handled on this device")
} }
w.filter = filter w.filter = filter
w.filter.SetNetwork(w.tun.WgAddress().Network) w.filter.SetNetwork(w.tun.WgAddress().Network)
w.tun.Wrapper().SetFilter(filter) w.tun.FilteredDevice().SetFilter(filter)
return nil return nil
} }
// GetFilter returns packet filter used by interface if it uses userspace device implementation // GetFilter returns packet filter used by interface if it uses userspace device implementation
func (w *WGIface) GetFilter() PacketFilter { func (w *WGIface) GetFilter() device.PacketFilter {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -168,16 +196,16 @@ func (w *WGIface) GetFilter() PacketFilter {
} }
// GetDevice to interact with raw device (with filtering) // GetDevice to interact with raw device (with filtering)
func (w *WGIface) GetDevice() *DeviceWrapper { func (w *WGIface) GetDevice() *device.FilteredDevice {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
return w.tun.Wrapper() return w.tun.FilteredDevice()
} }
// GetStats returns the last handshake time, rx and tx bytes for the given peer // GetStats returns the last handshake time, rx and tx bytes for the given peer
func (w *WGIface) GetStats(peerKey string) (WGStats, error) { func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return w.configurer.getStats(peerKey) return w.configurer.GetStats(peerKey)
} }
func (w *WGIface) waitUntilRemoved() error { func (w *WGIface) waitUntilRemoved() error {

View File

@@ -2,6 +2,8 @@
package iface package iface
import "fmt"
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
// this function is different on Android // this function is different on Android
@@ -17,3 +19,8 @@ func (w *WGIface) Create() error {
w.configurer = cfgr w.configurer = cfgr
return nil return nil
} }
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on non mobile")
}

View File

@@ -0,0 +1,24 @@
package iface
import (
"fmt"
)
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one.
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
w.mu.Lock()
defer w.mu.Unlock()
cfgr, err := w.tun.Create(routes, dns, searchDomains)
if err != nil {
return err
}
w.configurer = cfgr
return nil
}
// Create this function make sense on mobile only
func (w *WGIface) Create() error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -7,38 +7,8 @@ import (
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
"github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack"
) )
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address)
if err != nil {
return nil, err
}
wgIFace := &WGIface{
userspaceBind: true,
}
if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil
}
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil
}
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}
// Create creates a new Wireguard interface, sets a given IP and brings it up. // Create creates a new Wireguard interface, sets a given IP and brings it up.
// Will reuse an existing one. // Will reuse an existing one.
// this function is different on Android // this function is different on Android
@@ -64,3 +34,8 @@ func (w *WGIface) Create() error {
return backoff.Retry(operation, backOff) return backoff.Retry(operation, backOff)
} }
// CreateOnAndroid this function make sense on mobile only
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
return fmt.Errorf("this function has not implemented on this platform")
}

View File

@@ -0,0 +1,10 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/device"
)
// GetInterfaceGUIDString returns an interface GUID. This is useful on Windows only
func (w *WGIface) GetInterfaceGUIDString() (string, error) {
return w.tun.(*device.TunDevice).GetInterfaceGUIDString()
}

View File

@@ -6,7 +6,10 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
) )
type MockWGIface struct { type MockWGIface struct {
@@ -14,7 +17,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() WGAddress AddressFunc func() device.WGAddress
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
@@ -23,11 +26,12 @@ type MockWGIface struct {
AddAllowedIPFunc func(peerKey string, allowedIP string) error AddAllowedIPFunc func(peerKey string, allowedIP string) error
RemoveAllowedIPFunc func(peerKey string, allowedIP string) error RemoveAllowedIPFunc func(peerKey string, allowedIP string) error
CloseFunc func() error CloseFunc func() error
SetFilterFunc func(filter PacketFilter) error SetFilterFunc func(filter device.PacketFilter) error
GetFilterFunc func() PacketFilter GetFilterFunc func() device.PacketFilter
GetDeviceFunc func() *DeviceWrapper GetDeviceFunc func() *device.FilteredDevice
GetStatsFunc func(peerKey string) (WGStats, error) GetStatsFunc func(peerKey string) (configurer.WGStats, error)
GetInterfaceGUIDStringFunc func() (string, error) GetInterfaceGUIDStringFunc func() (string, error)
GetProxyFunc func() wgproxy.Proxy
} }
func (m *MockWGIface) GetInterfaceGUIDString() (string, error) { func (m *MockWGIface) GetInterfaceGUIDString() (string, error) {
@@ -50,7 +54,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc() return m.NameFunc()
} }
func (m *MockWGIface) Address() WGAddress { func (m *MockWGIface) Address() device.WGAddress {
return m.AddressFunc() return m.AddressFunc()
} }
@@ -86,18 +90,23 @@ func (m *MockWGIface) Close() error {
return m.CloseFunc() return m.CloseFunc()
} }
func (m *MockWGIface) SetFilter(filter PacketFilter) error { func (m *MockWGIface) SetFilter(filter device.PacketFilter) error {
return m.SetFilterFunc(filter) return m.SetFilterFunc(filter)
} }
func (m *MockWGIface) GetFilter() PacketFilter { func (m *MockWGIface) GetFilter() device.PacketFilter {
return m.GetFilterFunc() return m.GetFilterFunc()
} }
func (m *MockWGIface) GetDevice() *DeviceWrapper { func (m *MockWGIface) GetDevice() *device.FilteredDevice {
return m.GetDeviceFunc() return m.GetDeviceFunc()
} }
func (m *MockWGIface) GetStats(peerKey string) (WGStats, error) { func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) {
return m.GetStatsFunc(peerKey) return m.GetStatsFunc(peerKey)
} }
func (m *MockWGIface) GetProxy() wgproxy.Proxy {
//TODO implement me
panic("implement me")
}

View File

@@ -0,0 +1,24 @@
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
wgIFace := &WGIface{
userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

View File

@@ -0,0 +1,34 @@
//go:build !ios
package iface
import (
"github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address)
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
var tun WGTunDevice
if netstack.IsEnabled() {
tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
} else {
tun = device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
}
wgIFace := &WGIface{
userspaceBind: true,
tun: tun,
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
}
return wgIFace, nil
}

Some files were not shown because too many files have changed in this diff Show More