Compare commits

...

80 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
Pascal Fischer
ab2e3fec72 expose resource type consts 2025-03-12 13:49:24 +01:00
Hakan Sariman
47f88f7057 Refactor routeIDLookup methods to use Addr() for resolved IP operations 2025-03-11 19:43:58 +08:00
Hakan Sariman
ee33a6ed7c Refactor RemoveLocalPeerStateRoute to eliminate resourceId parameter 2025-03-11 13:19:30 +08:00
Hakan Sariman
da662cfd08 Add source and destination resource IDs to FlowFields 2025-03-11 13:12:54 +08:00
Hakan Sariman
ed2ee1ee9d Merge branch 'feature/flow' into feat/flow-resid 2025-03-11 13:08:11 +08:00
Viktor Liu
76d73548d6 Fix more conflicts 2025-03-10 18:46:01 +01:00
Viktor Liu
11828a064a Fix conflict 2025-03-10 18:35:32 +01:00
Viktor Liu
0c2a3dd937 Merge branch 'main' into feature/flow 2025-03-10 18:30:45 +01:00
Zoltan Papp
cd9eff5331 Increase the timeout to 50 sec (#3481) 2025-03-10 18:23:47 +01:00
Viktor Liu
47dcf8d68c Fix forwarder IP source/destination (#3463) 2025-03-10 14:55:07 +01:00
Viktor Liu
80ceb80197 [client] Ignore candidates that are part of the the wireguard subnet (#3472) 2025-03-10 13:59:21 +01:00
Bethuel Mmbaga
cc8f6bcaf3 [management] Fix tests circular dependency (#3460)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-10 15:54:36 +03:00
Zoltan Papp
636a0e2475 [client] Fix engine restart (#3435)
- Refactor the network monitoring to handle one event and it after return
- In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer
- Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events.
2025-03-10 13:32:12 +01:00
Viktor Liu
e66e329bf6 [client] Add option to autostart netbird ui in the Windows installer (#3469) 2025-03-10 13:19:17 +01:00
Zoltan Papp
aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00
Zoltan Papp
6bef474e9e [client] Prevent panic in case of double close call (#3475)
Prevent panic in case of double close call
2025-03-10 13:16:28 +01:00
Maycon Santos
81040ff80a [docs] Update typo (#3477) 2025-03-10 11:52:36 +01:00
Viktor Liu
c73481aee4 [client] Enable windows stderr logs by default (#3476) 2025-03-10 11:30:49 +01:00
Hakan Sariman
92286b2541 Implement routeIDLookup for managing local and remote route IDs 2025-03-10 15:58:45 +08:00
Maycon Santos
d8bcf745b0 update integrations 2025-03-09 19:32:38 +01:00
Maycon Santos
8430139d80 fix missing method 2025-03-09 19:03:57 +01:00
Maycon Santos
a2962b4ce0 sync go.sum 2025-03-09 18:50:20 +01:00
Maycon Santos
16fffdb75b sync changes from #3426 2025-03-09 18:48:48 +01:00
Maycon Santos
036cecbf46 update integrations and go mod 2025-03-09 18:47:05 +01:00
Maycon Santos
3482852bb6 sync proto and sum 2025-03-09 18:02:33 +01:00
Maycon Santos
fd62665b1f Merge branch 'main' into feature/flow
# Conflicts:
#	client/cmd/testutil_test.go
#	client/firewall/iptables/router_linux.go
#	client/firewall/nftables/router_linux.go
#	client/firewall/uspfilter/allow_netbird.go
#	client/firewall/uspfilter/allow_netbird_windows.go
#	client/firewall/uspfilter/uspfilter_test.go
#	client/internal/engine.go
#	client/internal/engine_test.go
#	client/server/server_test.go
#	go.mod
#	go.sum
#	management/client/client_test.go
#	management/cmd/management.go
#	management/proto/management.pb.go
#	management/proto/management.proto
#	management/server/account.go
#	management/server/account_test.go
#	management/server/dns_test.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/tools.go
#	management/server/integrations/port_forwarding/controller.go
#	management/server/management_proto_test.go
#	management/server/management_test.go
#	management/server/nameserver_test.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/route_test.go
2025-03-09 17:42:16 +01:00
Viktor Liu
fc1da94520 [client, management] Add port forwarding (#3275)
Add initial support to ingress ports on the client code.

- new types where added
- new protocol messages and controller
2025-03-09 16:06:43 +01:00
Hakan Sariman
1ffe48f0d4 Add nil check in CheckRoutes to prevent potential panic 2025-03-08 12:54:33 +03:00
Hakan Sariman
a3b8a21385 Refactor CheckRoutes to return resource IDs for matching source and destination addresses 2025-03-08 12:26:53 +03:00
Hakan Sariman
86492b88c4 Refactor route handling to simplify route information and improve state management 2025-03-08 12:25:35 +03:00
Hakan Sariman
d08a629f9e Merge branch 'feature/flow' into feat/flow-resid 2025-03-08 12:18:02 +03:00
Viktor Liu
36da464413 Fix tracer test 2025-03-07 17:19:10 +01:00
Hakan Sariman
268e3404d3 Merge branch 'feature/flow' into feat/flow-resid 2025-03-07 18:52:11 +03:00
Hakan Sariman
54d0591833 Refactor route handling to use RouteWithResourceId for improved state management 2025-03-07 18:43:49 +03:00
Muzammil
ae6b61301c Muz/netbird dashboards (#3458)
* added all 3 dashboards

* update readme
2025-03-07 16:13:11 +01:00
Viktor Liu
86370a0e7b Use bytes for flows event id (#3439) 2025-03-07 16:12:47 +01:00
Philippe Vaucher
a444e551b3 [misc] Traefik config improvements (#3346)
* Remove deprecated docker-compose version

* Prettify docker-compose files

* Backports missing logging entries

* Fix signal port

* Add missing relay configuration

* Serve management over 33073 to avoid confusion
2025-03-07 16:10:11 +01:00
Zoltan Papp
53b9a2002f Print out the goroutine id (#3433)
The TXT logger prints out the actual go routine ID

This feature depends on 'loggoroutine' build tag

```go build -tags loggoroutine```
2025-03-07 14:06:47 +01:00
Viktor Liu
cb16d0f45f Align packet tracer behavior with actual code paths (#3424) 2025-03-07 14:03:45 +01:00
Viktor Liu
e8d8bd8f18 Add peer traffic rule IDs to allowed connections in flows (#3442) 2025-03-07 13:56:26 +01:00
Viktor Liu
8b07f21c28 Don't track intercepted packets (#3448) 2025-03-07 13:56:16 +01:00
Viktor Liu
54be772ffd Handle flow updates (#3455) 2025-03-07 13:56:00 +01:00
Zoltan Papp
4b76d93cec [client] Fix TURN-Relay switch (#3456)
- When a peer is connected with TURN and a Relay connection is established, do not force switching to Relay. Keep using TURN until disconnection.

-In the proxy preparation phase, the Bind Proxy does not set the remote conn as a fake address for Bind. When running the Work() function, the proper proxy instance updates the conn inside the Bind.
2025-03-07 12:00:25 +01:00
Viktor Liu
3c3a454e61 Fix merge regression 2025-03-06 16:54:15 +01:00
Viktor Liu
5ff77b3595 Add flow userspace counters (#3438) 2025-03-06 16:52:56 +01:00
Viktor Liu
b180edbe5c Track icmp with id only (#3447) 2025-03-06 14:51:23 +01:00
Hakan Sariman
de3b5c78d7 Fix nil pointer dereference in CheckRoutes method 2025-03-06 14:10:31 +03:00
Hakan Sariman
0b42f40cf6 Refactor route management to include resource IDs in state handling 2025-03-06 13:51:46 +03:00
Viktor Liu
062d1ec76f [misc] Update bug-issue-report.md template (#3449) 2025-03-06 01:10:37 +01:00
Viktor Liu
0a042ac36d Fix merge conflict 2025-03-05 19:11:20 +01:00
Viktor Liu
c111675dd8 [client] Handle large DNS packets in dns route resolution (#3441) 2025-03-05 18:57:17 +01:00
Hakan Sariman
e7f921d787 [client] add resource id fields to netflow events 2025-03-05 20:35:52 +03:00
Viktor Liu
e9f11fb11b Replace net.IP with netip.Addr (#3425) 2025-03-05 18:28:05 +01:00
hakansa
419ed275fa Handle TCP RST flag to transition connection state to closed (#3432) 2025-03-05 18:25:42 +01:00
Viktor Liu
2d4fcaf186 Fix proto numbering (#3436) 2025-03-04 16:57:25 +01:00
Viktor Liu
acf172b52c Add kernel conntrack counters (#3434) 2025-03-04 16:46:03 +01:00
Viktor Liu
8c81a823fa Add flow ACL IDs (#3421) 2025-03-04 16:43:07 +01:00
Maycon Santos
619c549547 sync port forwarding 2025-03-04 16:29:59 +01:00
hakansa
60ffe0dc87 [client] UI Refactor Icon Paths (#3420)
[client] UI Refactor Icon Paths (#3420)
2025-03-04 18:29:29 +03:00
Maycon Santos
9a713a0987 Merge branch 'feature/port-forwarding' into feature/flow
# Conflicts:
#	go.mod
#	go.sum
2025-03-04 16:28:57 +01:00
Pascal Fischer
c4945cd565 add cleanup scheduler + metrics 2025-03-04 16:21:52 +01:00
Viktor Liu
1e10c17ecb Fix tcp state (#3431) 2025-03-04 11:19:54 +01:00
Viktor Liu
bcc5824980 [client] Close userspace firewall properly (#3426) 2025-03-04 11:19:42 +01:00
robertgro
af5796de1c [client] Add Netbird GitHub link to the client ui about sub menu (#3372) 2025-03-03 17:32:50 +01:00
Philippe Vaucher
9d604b7e66 [client Fix env var typo (#3415) 2025-03-03 17:22:51 +01:00
Viktor Liu
96d5190436 Add icmp type and code to forwarder flow event (#3413) 2025-02-28 21:04:07 +01:00
Viktor Liu
d19c26df06 Fix log direction (#3412) 2025-02-28 21:03:40 +01:00
Viktor Liu
36e36414d9 Fix forwarder log displaying (#3411) 2025-02-28 20:53:01 +01:00
bcmmbaga
7e69589e05 Update management-integrations
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:49:56 +00:00
bcmmbaga
aa613ab79a Update golang.org/x/crypto/ssh
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:27:46 +00:00
Viktor Liu
6ead0ff95e Fix log format 2025-02-28 20:24:23 +01:00
Viktor Liu
0db65a8984 Add routed packet drop flow (#3410) 2025-02-28 20:04:59 +01:00
Pascal Fischer
c138807e95 remove log message 2025-02-28 19:54:50 +01:00
Viktor Liu
637c0c8949 Add icmp type and code (#3409) 2025-02-28 19:16:42 +01:00
Viktor Liu
c72e13d8e6 Add conntrack flows (#3406) 2025-02-28 19:16:29 +01:00
Maycon Santos
f6d7bccfa0 Add flow client with sender/receiver (#3405)
add an initial version of receiver client and flow manager receiver and sender
2025-02-28 17:16:18 +00:00
bcmmbaga
e3ed01cafb go mod tidy
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 17:10:44 +00:00
Viktor Liu
fa748a7ec2 Add userspace flow implementation (#3393) 2025-02-28 11:08:35 +01:00
Bethuel Mmbaga
82c12cc8ae [management] Handle transaction error on peer deletion (#3387)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 19:57:04 +00:00
232 changed files with 11546 additions and 2610 deletions

View File

@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -dA output:** **Is any other VPN software installed?**
If applicable, add the `netbird status -dA' command output. If yes, which one?
**Do you face any (non-mobile) client issues?** **Debug output**
Please provide the file created by `netbird debug for 1m -AS`. To help us resolve the problem, please attach the following debug output
We advise reviewing the anonymized files for any remaining PII.
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
**Additional context** **Additional context**
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -258,7 +258,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -325,8 +325,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -392,7 +392,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -461,7 +461,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:

View File

@@ -71,7 +71,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -150,7 +150,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4

View File

@@ -53,9 +53,9 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird.png - src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@@ -72,9 +72,9 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird.png - src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan error, 1) run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
run <- err clientErr <- err
} }
}() }()
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
} }
return startCtx.Err() return startCtx.Err()
case err := <-run: case err := <-clientErr:
if err != nil { return fmt.Errorf("startup: %w", err)
if stopErr := client.Stop(); stopErr != nil { case <-run:
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
} }
c.connect = client c.connect = client

View File

@@ -10,17 +10,18 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes) fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes) return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
} else { } else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes) fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -4,12 +4,13 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -75,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
_ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName) return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -166,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"",
) )
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -97,17 +97,17 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists") require.NoError(t, err, "failed check chain exists")
if ok { if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules) require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
} }
}) })
} }
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -15,7 +15,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -121,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
} }
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -128,7 +129,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }

View File

@@ -330,7 +330,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map

View File

@@ -4,21 +4,20 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
if err := ipt.Reset(nil); err != nil { if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -65,13 +65,13 @@ type Manager interface {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
action Action, action Action,
ipsetName string, ipsetName string,
comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -80,7 +80,15 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule // DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error DeleteRouteRule(rule Rule) error
@@ -94,8 +102,8 @@ type Manager interface {
// SetLegacyManagement sets the legacy management mode // SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Close closes the firewall manager
Reset(stateManager *statemanager.Manager) error Close(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error

View File

@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var ipset *nftables.Set var ipset *nftables.Set
if ipsetName != "" { if ipsetName != "" {
@@ -102,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -256,7 +256,6 @@ func (m *AclManager) addIOFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipset *nftables.Set, ipset *nftables.Set,
comment string,
) (*Rule, error) { ) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
@@ -338,7 +337,7 @@ func (m *AclManager) addIOFiltering(
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(ruleId)
chain := m.chainInputRules chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// We only need to record minimal interface state for potential recreation. // We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains // Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy // a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules. // cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -242,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "") rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion") require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
t.Cleanup(func() { t.Cleanup(func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state") require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset // Verify iptables output after reset
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule") _, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"), netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP, fw.ProtocolTCP,

View File

@@ -20,7 +20,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -228,6 +228,7 @@ func (r *router) createContainers() error {
// AddRouteFiltering appends a nftables rule to the routing chain // AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -236,7 +237,7 @@ func (r *router) AddRouteFiltering(
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }

View File

@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present // need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {

View File

@@ -3,21 +3,20 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create nftables manager: %w", err) return fmt.Errorf("create nftables manager: %w", err)
} }
if err := nft.Reset(nil); err != nil { if err := nft.Close(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err) return fmt.Errorf("reset nftables manager: %w", err)
} }

View File

@@ -4,39 +4,36 @@ package uspfilter
import ( import (
"context" "context"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {
@@ -48,7 +45,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
} }
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Close(stateManager)
} }
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time" "time"
@@ -22,30 +23,30 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@@ -3,14 +3,14 @@ package common
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() iface.WGAddress Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
} }

View File

@@ -1,20 +1,27 @@
// common.go
package conntrack package conntrack
import ( import (
"net" "fmt"
"sync" "net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
SourceIP net.IP FlowId uuid.UUID
DestIP net.IP Direction nftypes.Direction
SourcePort uint16 SourceIP netip.Addr
DestPort uint16 DestIP netip.Addr
lastSeen atomic.Int64 // Unix nano for atomic access lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout return time.Since(lastSeen) > timeout
} }
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection // ConnKey uniquely identifies a connection
type ConnKey struct { type ConnKey struct {
SrcIP IPAddr SrcIP netip.Addr
DstIP IPAddr DstIP netip.Addr
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
} }
// makeConnKey creates a connection key func (c ConnKey) String() string {
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
} }

View File

@@ -1,94 +1,67 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
} }
} }
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
} }
} }
}) })

View File

@@ -1,13 +1,17 @@
package conntrack package conntrack
import ( import (
"net" "context"
"fmt"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -19,18 +23,20 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct { type ICMPConnKey struct {
// Supports both IPv4 and IPv6 SrcIP netip.Addr
SrcIP [16]byte DstIP netip.Addr
DstIP [16]byte ID uint16
Sequence uint16 // ICMP sequence number }
ID uint16 // ICMP identifier
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct { type ICMPConnTrack struct {
BaseConnTrack BaseConnTrack
Sequence uint16 ICMPType uint8
ID uint16 ICMPCode uint8
} }
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
@@ -39,131 +45,201 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{} flowLogger nftypes.FlowLogger
ipPool *PreallocatedIPs
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger, logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := ICMPConnKey{
key := makeICMPKey(srcIP, dstIP, id, seq) SrcIP: srcIP,
DstIP: dstIP,
t.mutex.Lock() ID: id,
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
} }
t.mutex.Unlock()
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
return false conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
} }
if conn.timeoutExceeded(t.timeout) { return key, false
return false
}
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
} }
func (t *ICMPTracker) cleanupRoutine() { // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
} }
func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) cleanup() {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %v (timeout)", key) t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// makeICMPKey creates an ICMP connection key func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { t.flowLogger.StoreEvent(nftypes.EventFields{
return ICMPConnKey{ FlowID: conn.FlowId,
SrcIP: MakeIPAddr(srcIP), Type: typ,
DstIP: MakeIPAddr(dstIP), RuleID: ruleID,
ID: id, Direction: conn.Direction,
Sequence: seq, Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
} SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
} }

View File

@@ -1,39 +1,39 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
) )
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
} }
}) })
} }

View File

@@ -3,12 +3,16 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"net" "context"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -39,6 +43,35 @@ const (
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
type TCPState int type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const ( const (
TCPStateNew TCPState = iota TCPStateNew TCPState = iota
TCPStateSynSent TCPStateSynSent
@@ -53,19 +86,14 @@ const (
TCPStateClosed TCPStateClosed
) )
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState State TCPState
established atomic.Bool established atomic.Bool
tombstone atomic.Bool
sync.RWMutex sync.RWMutex
} }
@@ -79,78 +107,126 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state) t.established.Store(state)
} }
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
done chan struct{} tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
ipPool *PreallocatedIPs flowLogger nftypes.FlowLogger
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
timeout: timeout, timeout: timeout,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound processes an outbound TCP packet and updates connection state func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { key := ConnKey{
// Create key before lock SrcIP: srcIP,
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] t.connections[key] = conn
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.UpdateLastSeen()
conn.established.Store(false)
t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
}
t.mutex.Unlock() t.mutex.Unlock()
// Lock individual connection for state update t.sendEvent(nftypes.TypeStart, conn, ruleID)
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
if !isValidFlagCombination(flags) { key := ConnKey{
return false SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
} }
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
@@ -159,22 +235,26 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
// Handle RST packets // Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
conn.Lock() if conn.IsTombstone() {
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true return true
} }
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
return false conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
} }
conn.Lock() conn.Lock()
t.updateState(conn, flags, false) t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished() isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags) isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock() conn.Unlock()
@@ -183,18 +263,17 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed conn.UpdateLastSeen()
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d", state := conn.State
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) defer func() {
return if state != conn.State {
} t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
switch conn.State { switch state {
case TCPStateNew: case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent conn.State = TCPStateSynSent
@@ -203,11 +282,11 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateSynSent: case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound { if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished conn.State = TCPStateEstablished
conn.SetEstablished(true) conn.SetEstablished(true)
} else {
// Simultaneous open
conn.State = TCPStateSynReceived
} }
} }
@@ -225,22 +304,32 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
conn.State = TCPStateCloseWait conn.State = TCPStateCloseWait
} }
conn.SetEstablished(false) conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
switch { switch {
case flags&TCPFin != 0 && flags&TCPAck != 0: case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing conn.State = TCPStateClosing
case flags&TCPFin != 0: case flags&TCPFin != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPAck != 0: case flags&TCPAck != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
@@ -248,8 +337,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d", t.logger.Trace("TCP connection %s closed (simultaneous)", key)
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@@ -260,17 +349,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetTombstone()
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d", // Send close event for gracefully closed connections
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
} }
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
@@ -315,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false return false
} }
func (t *TCPTracker) cleanupRoutine() { func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -331,6 +417,12 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration var timeout time.Duration
switch { switch {
case conn.State == TCPStateTimeWait: case conn.State == TCPStateTimeWait:
@@ -341,29 +433,26 @@ func (t *TCPTracker) cleanup() {
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
lastSeen := conn.GetLastSeen() if conn.timeoutExceeded(timeout) {
if time.Since(lastSeen) > timeout {
// Return IPs to pool // Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
@@ -381,3 +470,21 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -9,11 +9,11 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc) require.Equal(t, !tt.wantDrop, isValid, tt.desc)
}) })
} }
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper() t.Helper()
// Send initial SYN // Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK // Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK // Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer // Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake") require.True(t, valid, "Data should be allowed after handshake")
}, },
}, },
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN // Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN // Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed") require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side // Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed") require.True(t, valid, "FIN should be allowed")
// Send final ACK // Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
}, },
{ {
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets // Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK // Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed") require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK // Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed") require.True(t, valid, "Final ACKs should be allowed")
}, },
}, },
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger) tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tt.test(t) tt.test(t)
}) })
} }
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established", name: "RST in established",
setupState: func() { setupState: func() {
// Establish connection first // Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: true, wantValid: true,
desc: "Should accept RST for established connection", desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection", name: "RST without connection",
setupState: func() {}, setupState: func() {},
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: false, wantValid: false,
desc: "Should reject RST without connection", desc: "Should reject RST without connection",
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST() tt.sendRST()
// Verify connection state is as expected // Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
@@ -220,63 +225,63 @@ func TestRSTHandling(t *testing.T) {
} }
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
} }
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
} }
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
if i%2 == 0 { if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else { } else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
} }
i++ i++
} }
@@ -287,14 +292,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
// Wait for connections to expire // Wait for connections to expire

View File

@@ -1,11 +1,15 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -18,6 +22,8 @@ const (
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
type UDPConnTrack struct { type UDPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
} }
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
@@ -26,89 +32,125 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{} flowLogger nftypes.FlowLogger
ipPool *PreallocatedIPs
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.mutex.Lock() t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
} }
t.mutex.Unlock()
conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false return false
} }
if conn.timeoutExceeded(t.timeout) { conn.UpdateLastSeen()
return false conn.UpdateCounters(nftypes.Ingress, size)
}
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && return true
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() { func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -120,44 +162,58 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %v (timeout)", conn) t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() { func (t *UDPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// GetConnection safely retrieves a connection state // GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock() t.mutex.RLock()
defer t.mutex.RUnlock() defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
conn, exists := t.connections[key] SrcIP: srcIP,
if !exists { DstIP: dstIP,
return nil, false SrcPort: srcPort,
DstPort: dstPort,
} }
conn, exists := t.connections[key]
return conn, true return conn, exists
} }
// Timeout returns the configured timeout duration for the tracker // Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration { func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,8 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"testing" "testing"
"time" "time"
@@ -29,54 +30,59 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger) tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker) assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done) assert.NotNil(t, tracker.tickerCancel)
}) })
} }
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked // Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key] conn, exists := tracker.connections[key]
require.True(t, exists) require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP)) assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Equal(dstIP)) assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort) assert.Equal(t, dstPort, conn.DestPort)
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger) tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
// Track outbound connection // Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct { tests := []struct {
name string name string
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
sleep time.Duration sleep time.Duration
@@ -93,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
}, },
{ {
name: "invalid source IP", name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP, dstIP: srcIP,
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
@@ -103,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{ {
name: "invalid destination IP", name: "invalid destination IP",
srcIP: dstIP, srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"), dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
sleep: 0, sleep: 0,
@@ -143,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 { if tt.sleep > 0 {
time.Sleep(tt.sleep) time.Sleep(tt.sleep)
} }
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
@@ -154,42 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval // Create tracker with custom cleanup interval
tracker := &UDPTracker{ tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), tickerCancel: tickerCancel,
ipPool: NewPreallocatedIPs(),
logger: logger, logger: logger,
flowLogger: flowLogger,
} }
// Start cleanup routine // Start cleanup routine
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
// Add some connections // Add some connections
connections := []struct { connections := []struct {
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
}{ }{
{ {
srcIP: net.ParseIP("192.168.1.2"), srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"), dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345, srcPort: 12345,
dstPort: 53, dstPort: 53,
}, },
{ {
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"), dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346, srcPort: 12346,
dstPort: 53, dstPort: 53,
}, },
} }
for _, conn := range connections { for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
} }
// Verify initial connections // Verify initial connections
@@ -211,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
} }
}) })
} }

View File

@@ -1,6 +1,8 @@
package forwarder package forwarder
import ( import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true return true
} }
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -29,6 +30,7 @@ const (
type Forwarder struct { type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
@@ -38,7 +40,7 @@ type Forwarder struct {
netstack bool netstack bool
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{ f := &Forwarder{
logger: logger, logger: logger,
flowLogger: flowLogger,
stack: s, stack: s,
endpoint: endpoint, endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger), udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,

View File

@@ -3,14 +3,30 @@ package forwarder
import ( import (
"context" "context"
"net" "net"
"net/netip"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// handleICMP handles ICMP packets from the network stack // handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel() defer cancel()
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root // TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil { if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader()) fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response // For Echo Requests, send and handle response
switch icmpHdr.Type() { if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
case header.ICMPv4Echo: f.handleEchoResponse(icmpHdr, conn, id)
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
} }
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
return true return true
} }
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool { func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err) f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true return
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err) f.logger.Error("Failed to read ICMP response: %v", err)
} }
return true return
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err) f.logger.Error("Failed to inject ICMP response: %v", err)
return true
return
} }
f.logger.Trace("Forwarded ICMP echo reply for %v", id) f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
return true epID(id), icmpHdr.Type(), icmpHdr.Code())
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
} }

View File

@@ -5,24 +5,38 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// handleTCP is called by the TCP forwarder for new connections. // handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID() id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err) f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
return return
} }
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection %v", id) success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep) go f.proxyTCP(id, inConn, outConn, ep, flowID)
} }
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() { defer func() {
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err) f.logger.Debug("forwarder: inConn close error: %v", err)
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.logger.Debug("forwarder: outConn close error: %v", err) f.logger.Debug("forwarder: outConn close error: %v", err)
} }
ep.Close() ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}() }()
// Create context for managing the proxy goroutines // Create context for managing the proxy goroutines
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
select { select {
case <-ctx.Done(): case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id) f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return return
case err := <-errChan: case err := <-errChan:
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err) f.logger.Error("proxyTCP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down TCP connection %v", id) f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return return
} }
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -5,10 +5,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -16,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -28,15 +31,17 @@ type udpPacketConn struct {
lastSeen atomic.Int64 lastSeen atomic.Int64
cancel context.CancelFunc cancel context.CancelFunc
ep tcpip.Endpoint ep tcpip.Endpoint
flowID uuid.UUID
} }
type udpForwarder struct { type udpForwarder struct {
sync.RWMutex sync.RWMutex
logger *nblog.Logger logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn flowLogger nftypes.FlowLogger
bufPool sync.Pool conns map[stack.TransportEndpointID]*udpPacketConn
ctx context.Context bufPool sync.Pool
cancel context.CancelFunc ctx context.Context
cancel context.CancelFunc
} }
type idleConn struct { type idleConn struct {
@@ -44,13 +49,14 @@ type idleConn struct {
conn *udpPacketConn conn *udpPacketConn
} }
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger, logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn), flowLogger: flowLogger,
ctx: ctx, conns: make(map[stack.TransportEndpointID]*udpPacketConn),
cancel: cancel, ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() any { New: func() any {
b := make([]byte, mtu) b := make([]byte, mtu)
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns { for id, conn := range f.conns {
conn.cancel() conn.cancel()
if err := conn.conn.Close(); err != nil { if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
} }
if err := conn.outConn.Close(); err != nil { if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
conn.ep.Close() conn.ep.Close()
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
for _, idle := range idleConns { for _, idle := range idleConns {
idle.conn.cancel() idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil { if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err) f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
} }
if err := idle.conn.outConn.Close(); err != nil { if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
} }
idle.conn.ep.Close() idle.conn.ep.Close()
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id) delete(f.conns, idle.id)
f.Unlock() f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
} }
} }
} }
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
_, exists := f.udpForwarder.conns[id] _, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock() f.udpForwarder.RUnlock()
if exists { if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id) f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
return return
} }
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message // TODO: Send ICMP error message
return return
} }
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if epErr != nil { if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr) f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn: outConn, outConn: outConn,
cancel: connCancel, cancel: connCancel,
ep: ep, ep: ep,
flowID: flowID,
} }
pConn.updateLastSeen() pConn.updateLastSeen()
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
pConn.cancel() pConn.cancel()
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id) success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
} }
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
defer func() { defer func() {
pConn.cancel() pConn.cancel()
if err := pConn.conn.Close(); err != nil { if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := pConn.outConn.Close(); err != nil { if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
ep.Close() ep.Close()
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
f.udpForwarder.Lock() f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}() }()
errChan := make(chan error, 2) errChan := make(chan error, 2)
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
select { select {
case <-ctx.Done(): case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id) f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return return
case err := <-errChan: case err := <-errChan:
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err) f.logger.Error("proxyUDP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down UDP connection %v", id) f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return return
} }
} }
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() { func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano()) c.lastSeen.Store(time.Now().UnixNano())
} }

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32) m.ipv4Bitmap[high] |= 1 << (low % 32)
} }
func (m *localIPManager) checkBitmapBit(ip net.IP) bool { func (m *localIPManager) checkBitmapBit(ip []byte) bool {
ipv4 := ip.To4() high := (uint16(ip[0]) << 8) | uint16(ip[1])
if ipv4 == nil { low := (uint16(ip[2]) << 8) | uint16(ip[3])
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
} }
@@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil return nil
} }
func (m *localIPManager) IsLocalIP(ip net.IP) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil { if ip.Is4() {
return m.checkBitmapBit(ipv4) return m.checkBitmapBit(ip.AsSlice())
} }
return false return false

View File

@@ -2,90 +2,91 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestLocalIPManager(t *testing.T) { func TestLocalIPManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupAddr iface.WGAddress setupAddr wgaddr.Address
testIP net.IP testIP netip.Addr
expected bool expected bool
}{ }{
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
}, },
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
}, },
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
}, },
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
}, },
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
}, },
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("fe80::"), IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128), Mask: net.CIDRMask(64, 128),
}, },
}, },
testIP: net.ParseIP("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
}, },
} }
@@ -95,7 +96,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager() manager := newLocalIPManager()
mock := &IFaceMock{ mock := &IFaceMock{
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return tt.setupAddr return tt.setupAddr
}, },
} }
@@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests)) t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) { t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip)) result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip) require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
}) })
} }

View File

@@ -1,4 +1,4 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking // Package log provides a high-performance, non-blocking logger for userspace networking
package log package log
import ( import (
@@ -13,13 +13,12 @@ import (
) )
const ( const (
maxBatchSize = 1024 * 16 // 16KB max batch size maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2 // 2KB per message maxMessageSize = 1024 * 2
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
) )
// Level represents log severity
type Level uint32 type Level uint32
const ( const (
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC", LevelTrace: "TRAC",
} }
// Logger is a high-performance, non-blocking logger type logMessage struct {
type Logger struct { level Level
output io.Writer format string
level atomic.Uint32 args []any
buffer *ringBuffer
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
} }
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger { func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{ l := &Logger{
output: logrusLogger.Out, output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize), msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() interface{} { New: func() any {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize) b := make([]byte, 0, maxMessageSize)
return &b return &b
}, },
}, },
} }
logrusLevel := logrusLogger.GetLevel() logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel)) l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)] level := levelStrings[Level(logrusLevel)]
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
return l return l
} }
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) { func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level)) l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { func (l *Logger) log(level Level, format string, args ...any) {
*buf = (*buf)[:0] select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
// Timestamp default:
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
} }
*buf = append(*buf, '\n')
} }
func (l *Logger) log(level Level, format string, args ...interface{}) { // Error logs a message at error level
bufp := l.bufPool.Get().(*[]byte) func (l *Logger) Error(format string, args ...any) {
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...) l.log(LevelError, format, args...)
} }
} }
func (l *Logger) Warn(format string, args ...interface{}) { // Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) { if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...) l.log(LevelWarn, format, args...)
} }
} }
func (l *Logger) Info(format string, args ...interface{}) { // Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) { if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...) l.log(LevelInfo, format, args...)
} }
} }
func (l *Logger) Debug(format string, args ...interface{}) { // Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...) l.log(LevelDebug, format, args...)
} }
} }
func (l *Logger) Trace(format string, args ...interface{}) { // Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...) l.log(LevelTrace, format, args...)
} }
} }
// worker periodically flushes the buffer func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() { func (l *Logger) worker() {
defer l.wg.Done() defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval) ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop() defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize) buffer := make([]byte, 0, maxBatchSize)
for { for {
select { select {
case <-l.shutdown: case <-l.shutdown:
l.handleShutdown(&buffer)
return return
case <-ticker.C: case <-ticker.C:
// Read accumulated messages l.flushBuffer(&buffer)
n, _ := l.buffer.Read(buf[:cap(buf)]) case msg := <-l.msgChannel:
if n == 0 { l.processMessage(msg, &buffer)
continue l.processBatch(&buffer)
}
// Write batch
_, _ = l.output.Write(buf[:n])
} }
} }
} }

View File

@@ -0,0 +1,121 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -1,85 +0,0 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"net"
"net/netip" "net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -12,14 +11,14 @@ import (
// PeerRule to handle management of rules // PeerRule to handle management of rules
type PeerRule struct { type PeerRule struct {
id string id string
ip net.IP mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
sPort *firewall.Port sPort *firewall.Port
dPort *firewall.Port dPort *firewall.Port
drop bool drop bool
comment string
udpHook func([]byte) bool udpHook func([]byte) bool
} }
@@ -31,6 +30,7 @@ func (r *PeerRule) ID() string {
type RouteRule struct { type RouteRule struct {
id string id string
mgmtId []byte
sources []netip.Prefix sources []netip.Prefix
destination netip.Prefix destination netip.Prefix
proto firewall.Protocol proto firewall.Protocol

View File

@@ -2,7 +2,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net/netip"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
} }
type PacketTrace struct { type PacketTrace struct {
SourceIP net.IP SourceIP netip.Addr
DestinationIP net.IP DestinationIP netip.Addr
Protocol string Protocol string
SourcePort uint16 SourcePort uint16
DestinationPort uint16 DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
} }
type PacketBuilder struct { type PacketBuilder struct {
SrcIP net.IP SrcIP netip.Addr
DstIP net.IP DstIP netip.Addr
Protocol fw.Protocol Protocol fw.Protocol
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP, SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP, DstIP: p.DstIP.AsSlice(),
} }
} }
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP) return m.traceInbound(packetData, trace, d, srcIP, dstIP)
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return trace if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
} }
if !m.handleRouting(trace) { if !m.handleRouting(trace) {
return trace return trace
} }
if m.nativeRouter { if m.nativeRouter.Load() {
return m.handleNativeRouter(trace) return m.handleNativeRouter(trace)
} }
return m.handleRouteACLs(trace, d, srcIP, dstIP) return m.handleRouteACLs(trace, d, srcIP, dstIP)
} }
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP) allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found" msg := "No existing connection found"
if allowed { if allowed {
msg = m.buildConntrackStateMessage(d) msg = m.buildConntrackStateMessage(d)
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg return msg
} }
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.localForwarding { trace.AddResult(StageRouting, "Packet destined for local delivery", true)
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true return true
} }
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StagePeerACL, msg, true)
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
msg := "Allowed by peer ACL rules"
if blocked {
msg = "Blocked by peer ACL rules"
}
trace.AddResult(StagePeerACL, msg, !blocked)
// Handle netstack mode
if m.netstack { if m.netstack {
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
} }
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) // In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return true return true
} }
func (m *Manager) handleRouting(trace *PacketTrace) bool { func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled { if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false) trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false return false
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace return trace
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto := getProtocolFromPacket(d) proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
msg := "Allowed by route ACLs" strId := string(id)
if id == nil {
strId = "<no id>"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
if !allowed { if !allowed {
msg = "Blocked by route ACLs" msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
} }
trace.AddResult(StageRouteACL, msg, allowed) trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil { if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
} }
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state // will create or update the connection state
dropped := m.processOutgoingHooks(packetData) dropped := m.processOutgoingHooks(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else { } else {

View File

@@ -0,0 +1,440 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -65,9 +67,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
// outgoingRules is used for hooks only // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet outgoingRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@@ -79,9 +81,9 @@ type Manager struct {
// indicates whether server routes are disabled // indicates whether server routes are disabled
disableServerRoutes bool disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves // indicates whether we forward packets not destined for ourselves
routingEnabled bool routingEnabled atomic.Bool
// indicates whether we leave forwarding and filtering to the native firewall // indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool nativeRouter atomic.Bool
// indicates whether we track outbound connections // indicates whether we track outbound connections
stateful bool stateful bool
// indicates whether wireguards runs in netstack mode // indicates whether wireguards runs in netstack mode
@@ -94,8 +96,9 @@ type Manager struct {
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
} }
// decoder for packages // decoder for packages
@@ -112,16 +115,16 @@ type decoder struct {
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes) return create(iface, nil, disableServerRoutes, flowLogger)
} }
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
if nativeFirewall == nil { if nativeFirewall == nil {
return nil, errors.New("native firewall is nil") return nil, errors.New("native firewall is nil")
} }
mgr, err := create(iface, nativeFirewall, disableServerRoutes) mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -148,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
} }
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv() disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{ m := &Manager{
@@ -166,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}, },
}, },
nativeFirewall: nativeFirewall, nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface, wgIface: iface,
localipmanager: newLocalIPManager(), localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes, disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack, stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()), logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
} }
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@@ -185,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
} }
// netstack needs the forwarder for local traffic // netstack needs the forwarder for local traffic
@@ -208,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil { if m.forwarder.Load() == nil {
return nil return nil
} }
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
@@ -218,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering( if _, err := m.AddRouteFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix, wgPrefix,
firewall.ProtocolALL, firewall.ProtocolALL,
@@ -251,20 +256,20 @@ func (m *Manager) determineRouting() error {
switch { switch {
case disableUspRouting: case disableUspRouting:
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is disabled") log.Info("userspace routing is disabled")
case m.disableServerRoutes: case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack // if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("server routes are disabled") log.Info("server routes are disabled")
case forceUserspaceRouter: case forceUserspaceRouter:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
@@ -272,19 +277,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("native routing is enabled") log.Info("native routing is enabled")
default: default:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing enabled by default") log.Info("userspace routing enabled by default")
} }
if m.routingEnabled && !m.nativeRouter { if m.routingEnabled.Load() && !m.nativeRouter.Load() {
return m.initForwarder() return m.initForwarder()
} }
@@ -293,24 +298,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors // initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error { func (m *Manager) initForwarder() error {
if m.forwarder != nil { if m.forwarder.Load() != nil {
return nil return nil
} }
// Only supported in userspace mode as we need to inject packets back into wireguard directly // Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice() intf := m.wgIface.GetWGDevice()
if intf == nil { if intf == nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil { if err != nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
} }
m.forwarder = forwarder m.forwarder.Store(forwarder)
log.Debug("forwarder initialized") log.Debug("forwarder initialized")
@@ -326,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
} }
@@ -337,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveNatRule(pair)
} }
return nil return nil
@@ -348,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
_ string, _ string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, mgmtId: id,
ip: i,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
comment: comment,
} }
if ipNormalized := ip.To4(); ipNormalized != nil { if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
} }
if s := r.ip.String(); s == "0.0.0.0" || s == "::" { if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@@ -391,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet) m.incomingRules[r.ip] = make(RuleSet)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -407,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id,
sources: sources, sources: sources,
destination: destination, destination: destination,
proto: proto, proto: proto,
@@ -425,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
action: action, action: action,
} }
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule) m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
@@ -461,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok { if _, ok := m.incomingRules[r.ip][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id) delete(m.incomingRules[r.ip], r.id)
return nil return nil
} }
@@ -497,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
} }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData) return m.processOutgoingHooks(packetData, size)
} }
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData) return m.dropFilter(packetData, size)
} }
// UpdateLocalIPs updates the list of local IPs // UpdateLocalIPs updates the list of local IPs
@@ -511,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface) return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -527,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return false return false
} }
// Track all protocols if stateful mode is enabled if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
if m.stateful { return true
switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
}
} }
// Process UDP hooks even if stateful mode is disabled if m.stateful {
if d.decoded[1] == layers.LayerTypeUDP { m.trackOutbound(d, srcIP, dstIP, size)
return m.checkUDPHooks(d, dstIP, packetData)
} }
return false return false
} }
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
switch d.decoded[0] { switch d.decoded[0] {
case layers.LayerTypeIPv4: case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
case layers.LayerTypeIPv6: case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
default: default:
return nil, nil return netip.Addr{}, netip.Addr{}
} }
} }
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 { func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8 var flags uint8
if tcp.SYN { if tcp.SYN {
@@ -596,45 +591,70 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
m.udpTracker.TrackOutbound( transport := d.decoded[1]
srcIP, switch transport {
dstIP, case layers.LayerTypeUDP:
uint16(d.udp.SrcPort), m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
uint16(d.udp.DstPort), case layers.LayerTypeTCP:
) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
}
} }
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { transport := d.decoded[1]
if rules, exists := m.outgoingRules[ipKey]; exists { switch transport {
for _, rule := range rules { case layers.LayerTypeUDP:
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
return rule.udpHook(packetData) case layers.LayerTypeTCP:
} flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
}
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
} }
} }
} }
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { // Check IPv4 unspecified address
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
m.icmpTracker.TrackOutbound( for _, rule := range rules {
srcIP, if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
dstIP, return rule.udpHook(packetData)
d.icmp4.Id, }
d.icmp4.Seq, }
)
} }
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
} }
// dropFilter implements filtering logic for incoming packets. // dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool { func (m *Manager) dropFilter(packetData []byte, size int) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -643,19 +663,19 @@ func (m *Manager) dropFilter(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked. // This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false return false
} }
if m.localipmanager.IsLocalIP(dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData) return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
} }
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
@@ -663,10 +683,29 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) { ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s", if blocked {
srcIP, dstIP) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true return true
} }
@@ -675,6 +714,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
return m.handleNetstackLocalTraffic(packetData) return m.handleNetstackLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP, ruleID, size)
return false return false
} }
@@ -684,12 +726,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
return false return false
} }
if m.forwarder == nil { if m.forwarder.Load() == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true return true
} }
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err) m.logger.Error("Failed to inject local packet: %v", err)
} }
@@ -699,30 +741,43 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic. // handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled { if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP) srcIP, dstIP)
return true return true
} }
// Pass to native stack if native router is enabled or forced // Pass to native stack if native router is enabled or forced
if m.nativeRouter { if m.nativeRouter.Load() {
return false return false
} }
proto := getProtocolFromPacket(d) proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) { if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
srcIP, srcPort, dstIP, dstPort, proto) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
})
return true return true
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err) m.logger.Error("Failed to inject incoming packet: %v", err)
} }
@@ -730,16 +785,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
return true return true
} }
func getProtocolFromPacket(d *decoder) firewall.Protocol { func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return firewall.ProtocolTCP return firewall.ProtocolTCP, nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return firewall.ProtocolUDP return firewall.ProtocolUDP, nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP return firewall.ProtocolICMP, nftypes.ICMP
default: default:
return firewall.ProtocolALL return firewall.ProtocolALL, nftypes.ProtocolUnknown
} }
} }
@@ -767,7 +822,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true return true
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound( return m.tcpTracker.IsValidInbound(
@@ -776,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
uint16(d.tcp.SrcPort), uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort), uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp), getTCPFlags(&d.tcp),
size,
) )
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@@ -784,6 +840,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
dstIP, dstIP,
uint16(d.udp.SrcPort), uint16(d.udp.SrcPort),
uint16(d.udp.DstPort), uint16(d.udp.DstPort),
size,
) )
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@@ -791,8 +848,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
srcIP, srcIP,
dstIP, dstIP,
d.icmp4.Id, d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(), d.icmp4.TypeCode.Type(),
size,
) )
// TODO: ICMPv6 // TODO: ICMPv6
@@ -812,25 +869,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return false return nil, false
} }
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
return filter return mgmtId, filter
} }
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
return filter return mgmtId, filter
} }
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return filter return mgmtId, filter
} }
// Default policy: DROP ALL // Default policy: DROP ALL
return true return nil, true
} }
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool { func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
@@ -850,15 +909,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false return false
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue continue
} }
if rule.protoLayer == layerTypeAll { if rule.protoLayer == layerTypeAll {
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
if payloadLayer != rule.protoLayer { if payloadLayer != rule.protoLayer {
@@ -868,39 +927,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
switch payloadLayer { switch payloadLayer {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) { if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule) // if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook // we ignore rule.drop and call this hook
if rule.udpHook != nil { if rule.udpHook != nil {
return rule.udpHook(packetData), true return rule.mgmtId, rule.udpHook(packetData), true
} }
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
} }
return false, false return nil, false, false
} }
// routeACLsPass returns treu if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
return false return nil, false
} }
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
@@ -940,36 +996,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}}, dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook, udpHook: hook,
} }
if ip.To4() != nil { if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
} }
m.mutex.Lock() m.mutex.Lock()
if in { if in {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule) m.incomingRules[r.ip] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) m.outgoingRules[r.ip] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip][r.id] = r
} }
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@@ -1017,20 +1069,21 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.forwarder == nil { fwder := m.forwarder.Load()
if fwder == nil {
return nil return nil
} }
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
// don't stop forwarder if in use by netstack // don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
return nil return nil
} }
m.forwarder.Stop() fwder.Stop()
m.forwarder = nil m.forwarder.Store(nil)
log.Debug("forwarder stopped") log.Debug("forwarder stopped")

View File

@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false, stateful: false,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, _, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern // Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0] ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, _, err := m.AddPeerFiltering(
nil,
ip,
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}}, &fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}}, &fw.Port{Values: []uint16{80}},
fw.ActionAccept, "", "explicit return") fw.ActionAccept,
"",
)
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true, stateful: true,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, _, err := m.AddPeerFiltering(
fw.ActionDrop, "", "default drop") nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@@ -158,9 +169,9 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -203,9 +214,9 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Test packet // Test packet
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut) manager.processOutgoingHooks(testOut, 0)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn) manager.dropFilter(testIn, 0)
} }
}) })
} }
@@ -251,9 +262,9 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -450,9 +461,9 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
// Setup scenario // Setup scenario
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -577,9 +588,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
} }
@@ -668,9 +676,9 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
} }
@@ -787,9 +792,9 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Pre-generate test packets // Pre-generate test packets
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
}) })
@@ -875,9 +877,9 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}) })
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
}) })
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering( _, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
r.sources,
r.dest,
r.proto,
nil,
r.port,
fw.ActionAccept,
)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tc := range cases { for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
} }
} }

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
@@ -26,20 +26,20 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
}, },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, manager) require.NotNil(t, manager)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet manager.wgNetwork = wgNet
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist") require.True(t, isDropped, "Packet should be dropped when no rules exist")
}) })
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP(tc.ruleIP), net.ParseIP(tc.ruleIP),
tc.ruleProto, tc.ruleProto,
tc.ruleSrcPort, tc.ruleSrcPort,
tc.ruleDstPort, tc.ruleDstPort,
tc.ruleAction, tc.ruleAction,
"", "",
tc.name,
) )
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, rules) require.NotEmpty(t, rules)
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
}) })
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped) require.Equal(t, tc.shouldBeBlocked, isDropped)
}) })
} }
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
@@ -302,15 +302,15 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}, },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter) require.False(tb, manager.nativeRouter.Load())
tb.Cleanup(func() { tb.Cleanup(func() {
require.NoError(tb, manager.Reset(nil)) require.NoError(tb, manager.Close(nil))
}) })
return manager return manager
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources, tc.rule.sources,
tc.rule.dest, tc.rule.dest,
tc.rule.proto, tc.rule.proto,
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule)) require.NoError(t, manager.DeleteRouteRule(rule))
}) })
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)
}) })
} }
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
var rules []fw.Rule var rules []fw.Rule
for _, r := range tc.rules { for _, r := range tc.rules {
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil,
r.sources, r.sources,
r.dest, r.dest,
r.proto, r.proto,
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
}) })
for i, p := range tc.packets { for i, p := range tc.packets {
srcIP := net.ParseIP(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := net.ParseIP(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
} }
}) })

View File

@@ -1,8 +1,10 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -16,15 +18,17 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
@@ -50,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress { func (i *IFaceMock) Address() wgaddr.Address {
if i.AddressFunc == nil { if i.AddressFunc == nil {
return iface.WGAddress{} return wgaddr.Address{}
} }
return i.AddressFunc() return i.AddressFunc()
} }
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
ip := net.ParseIP("192.168.1.1") ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok { if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok { if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string name string
in bool in bool
expDir fw.RuleDirection expDir fw.RuleDirection
ip net.IP ip netip.Addr
dPort uint16 dPort uint16
hook func([]byte) bool hook func([]byte) bool
expectedID string expectedID string
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook", name: "Test Outgoing UDP Packet Hook",
in: false, in: false,
expDir: fw.RuleDirectionOUT, expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1), ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000, dPort: 8000,
hook: func([]byte) bool { return true }, hook: func([]byte) bool { return true },
}, },
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook", name: "Test Incoming UDP Packet Hook",
in: true, in: true,
expDir: fw.RuleDirectionIN, expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback, ip: netip.MustParseAddr("::1"),
dPort: 9000, dPort: 9000,
hook: func([]byte) bool { return false }, hook: func([]byte) bool { return false },
}, },
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
for _, rule := range manager.incomingRules[tt.ip.String()] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip.String()] { for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} }
if !tt.ip.Equal(addedRule.ip) { if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -246,15 +248,14 @@ func TestManagerReset(t *testing.T) {
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) _, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
err = m.Reset(nil) err = m.Close(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -268,8 +269,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"), IP: net.ParseIP("100.10.0.0"),
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment) _, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -328,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes()) { if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
if err = m.Reset(nil); err != nil { if err = m.Close(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -347,17 +347,17 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface, false) manager, err := Create(iface, false, flowLogger)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -401,9 +401,9 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false hookCalled := false
hookID := manager.AddUDPPacketHook( hookID := manager.AddUDPPacketHook(
false, false,
net.ParseIP("100.10.0.100"), netip.MustParseAddr("100.10.0.100"),
53, 53,
func([]byte) bool { func([]byte) bool {
hookCalled = true hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes()) result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes()) result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result) require.False(t, result)
} }
@@ -479,12 +479,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -530,12 +530,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, },
} }
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
// Set up packet parameters // Set up packet parameters
srcIP := net.ParseIP("100.10.0.1") srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100") dstIP := netip.MustParseAddr("100.10.0.100")
srcPort := uint16(51334) srcPort := uint16(51334)
dstPort := uint16(53) dstPort := uint16(53)
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{ outboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: srcIP, SrcIP: srcIP.AsSlice(),
DstIP: dstIP, DstIP: dstIP.AsSlice(),
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
outboundUDP := &layers.UDP{ outboundUDP := &layers.UDP{
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes()) drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{ inboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: dstIP, // Original destination is now source SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP, // Original source is now destination DstIP: srcIP.AsSlice(), // Original source is now destination
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
inboundUDP := &layers.UDP{ inboundUDP := &layers.UDP{
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes()) drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes()) drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes()) drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@@ -5,7 +5,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
@@ -14,6 +13,8 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -52,9 +53,10 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
address: address,
} }
rc := receiverCreator{ rc := receiverCreator{
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil return s.udpMux, nil
} }
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) { func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock() b.endpointsMu.Unlock()
return fakeUDPAddr, nil
} }
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) { func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
defer b.endpointsMu.Unlock() defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
} }
} }
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message { func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message) return msgsPool.Get().(*[]ipv6.Message)
} }

View File

@@ -17,6 +17,8 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// FilterFn is a function that filters out candidates based on the address. // FilterFn is a function that filters out candidates based on the address.
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
filterFn: params.FilterFn, filterFn: params.FilterFn,
address: params.WGAddress,
} }
// embed UDPMux // embed UDPMux
@@ -118,6 +122,7 @@ type udpConn struct {
filterFn FilterFn filterFn FilterFn
// TODO: reset cache on route changes // TODO: reset cache on route changes
addrCache sync.Map addrCache sync.Map
address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil { if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err) log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else { } else {

View File

@@ -9,13 +9,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct { type WGTunDevice struct {
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -31,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *WGTunDevice) WgAddress() WGAddress { func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -29,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -10,16 +11,16 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host // DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not. // Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument. // Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) { if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -30,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) UpdateAddr(addr WGAddress) error { func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }

View File

@@ -14,12 +14,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
name string name string
address WGAddress address wgaddr.Address
wgPort int wgPort int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock, UDPConn: rawSock,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil return t.udpMux, nil
} }
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
return closErr return closErr
} }
func (t *TunKernelDevice) WgAddress() WGAddress { func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,12 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
return nil return nil
} }
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
return nil return nil
} }
func (t *TunNetstackDevice) WgAddress() WGAddress { func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -12,11 +12,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type USPDevice struct { type USPDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -28,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &USPDevice{
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *USPDevice) UpdateAddr(address WGAddress) error { func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
return nil return nil
} }
func (t *USPDevice) WgAddress() WGAddress { func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,13 +13,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -32,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
} }
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd" "github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name) link, err := freebsd.LinkByName(l.name)
if err != nil { if err != nil {
return fmt.Errorf("link by name: %w", err) return fmt.Errorf("link by name: %w", err)

View File

@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses //delete existing addresses
list, err := netlink.AddrList(l, 0) list, err := netlink.AddrList(l, 0)
if err != nil { if err != nil {

View File

@@ -7,13 +7,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -28,8 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
type WGAddress = device.WGAddress
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
} }
// Address returns the interface address // Address returns the interface address
func (w *WGIface) Address() device.WGAddress { func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr) addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,17 +3,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,

View File

@@ -6,17 +6,18 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -5,17 +5,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,12 +8,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{} wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil return wgIFace, nil
} }
if device.ModuleTunIsLoaded() { if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,16 +4,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -6,6 +6,7 @@ package mocks
import ( import (
net "net" net "net"
"net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
} }
// AddUDPPacketHook mocks base method. // AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
} }
// DropIncoming mocks base method. // DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0) ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
} }
// DropOutgoing mocks base method. // DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0) ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
} }
// RemovePacketHook mocks base method. // RemovePacketHook mocks base method.

View File

@@ -55,7 +55,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil { if err != nil {
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err) log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
} }
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil

View File

@@ -1,29 +1,29 @@
package device package wgaddr
import ( import (
"fmt" "fmt"
"net" "net"
) )
// WGAddress WireGuard parsed address // Address WireGuard parsed address
type WGAddress struct { type Address struct {
IP net.IP IP net.IP
Network *net.IPNet Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return WGAddress{}, err return Address{}, err
} }
return WGAddress{ return Address{
IP: ip, IP: ip,
Network: network, Network: network,
}, nil }, nil
} }
func (addr WGAddress) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,13 +17,13 @@ import (
type ProxyBind struct { type ProxyBind struct {
Bind *bind.ICEBind Bind *bind.ICEBind
wgAddr *net.UDPAddr fakeNetIP *netip.AddrPort
wgEndpoint *bind.Endpoint wgBindEndpoint *bind.Endpoint
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
@@ -33,20 +34,24 @@ type ProxyBind struct {
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgAddr = addr p.fakeNetIP = fakeNetIP
p.wgEndpoint = addrToEndpoint(addr) p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return err return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
@@ -54,6 +59,8 @@ func (p *ProxyBind) Work() {
return return
} }
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedMu.Lock() p.pausedMu.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.pausedMu.Unlock()
@@ -93,7 +100,7 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr) p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
msg := bind.RecvMessage{ msg := bind.RecvMessage{
Endpoint: p.wgEndpoint, Endpoint: p.wgBindEndpoint,
Buffer: buf[:n], Buffer: buf[:n],
} }
p.Bind.RecvChan <- msg p.Bind.RecvChan <- msg
@@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
} }
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { // fakeAddress returns a fake address that is used to as an identifier for the peer.
ip, _ := netip.AddrFromSlice(addr.IP.To4()) // The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
return &bind.Endpoint{AddrPort: addrPort} octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
} }

View File

@@ -6,8 +6,8 @@
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network" !define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
!define INSTALLER_NAME "netbird-installer.exe" !define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird" !define MAIN_APP_EXE "Netbird"
!define ICON "ui\\netbird.ico" !define ICON "ui\\assets\\netbird.ico"
!define BANNER "ui\\banner.bmp" !define BANNER "ui\\build\\banner.bmp"
!define LICENSE_DATA "..\\LICENSE" !define LICENSE_DATA "..\\LICENSE"
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}" !define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
@@ -22,6 +22,8 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}" !define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}" !define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -68,6 +70,9 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
@@ -80,8 +85,36 @@ ShowInstDetails Show
!insertmacro MUI_LANGUAGE "English" !insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
###################################################################### ######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR" EnVar::AddValueEx "path" "$INSTDIR"
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
Sleep 3000 Sleep 3000

View File

@@ -240,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
dPorts := convertPortInfo(rule.PortInfo) dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) addedRule, err := d.firewall.AddRouteFiltering(rule.Id, sources, destination, protocol, nil, dPorts, action)
if err != nil { if err != nil {
return "", fmt.Errorf("add route rule: %w", err) return "", fmt.Errorf("add route rule: %w", err)
} }
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
} }
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil return ruleID, rulesPair, nil
} }
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
var rules []firewall.Rule var rules []firewall.Rule
switch r.Direction { switch r.Direction {
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(r.Id, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. // TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already // We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(r.Id, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
} }
func (d *DefaultManager) addInRules( func (d *DefaultManager) addInRules(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment) rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return nil, nil return nil, nil
} }
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment) rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
direction int, direction int,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
comment string,
) id.RuleID { ) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil { if port != nil {
idStr += port.String() idStr += port.String()
} }

View File

@@ -1,6 +1,7 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
@@ -8,11 +9,14 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
@@ -45,20 +49,20 @@ func TestDefaultManager(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Close(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
@@ -339,20 +343,20 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Close(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// MockIFaceMapper is a mock of IFaceMapper interface. // MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
} }
// Address mocks base method. // Address mocks base method.
func (m *MockIFaceMapper) Address() iface.WGAddress { func (m *MockIFaceMapper) Address() wgaddr.Address {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address") ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(iface.WGAddress) ret0, _ := ret[0].(wgaddr.Address)
return ret0 return ret0
} }

View File

@@ -61,7 +61,7 @@ func NewConnectClient(
} }
// Run with main logic. // Run with main logic.
func (c *ConnectClient) Run(runningChan chan error) error { func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan) return c.run(MobileDependency{}, runningChan)
} }
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil) return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
defer c.statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { if c.ctx.Err() != nil {
return nil return nil
} }
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil && runningChanOpen { if runningChan != nil {
runningChan <- nil select {
close(runningChan) case runningChan <- struct{}{}:
runningChanOpen = false default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
return nil return nil
} }
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// SetNetworkMapPersistence enables or disables network map persistence. // SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved // When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored // through the Engine's getLatestNetworkMap method. When disabled, any stored

View File

@@ -22,6 +22,8 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -29,6 +31,8 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter device.PacketFilter
} }
@@ -37,9 +41,9 @@ func (w *mocWGIface) Name() string {
panic("implement me") panic("implement me")
} }
func (w *mocWGIface) Address() iface.WGAddress { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24") ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{ return wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
} }
@@ -455,7 +459,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet) packetfilter.EXPECT().SetNetwork(ipNet)
@@ -916,7 +920,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err return nil, err
} }
pf, err := uspfilter.Create(wgIface, false) pf, err := uspfilter.Create(wgIface, false, flowLogger)
if err != nil { if err != nil {
t.Fatalf("failed to create uspfilter: %v", err) t.Fatalf("failed to create uspfilter: %v", err)
return nil, err return nil, err
@@ -1015,7 +1019,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
mh.AssertExpectations(t) mh.AssertExpectations(t)
} }
// Reset mocks // Close mocks
if mh, ok := tc.expectedHandler.(*MockHandler); ok { if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.ExpectedCalls = nil mh.ExpectedCalls = nil
mh.Calls = nil mh.Calls = nil

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net" "net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true return true
} }
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil ip, err := netip.ParseAddr(s.runtimeIP)
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
} }

View File

@@ -5,15 +5,15 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@@ -1,15 +1,15 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
return nil return nil
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@@ -35,7 +35,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
@@ -191,7 +191,7 @@ type Engine struct {
persistNetworkMap bool persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup connSemaphore *semaphoregroup.SemaphoreGroup
flowManager types.FlowManager flowManager nftypes.FlowManager
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -234,7 +234,6 @@ func NewEngine(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
checks: checks, checks: checks,
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
flowManager: netflow.NewManager(clientCtx),
} }
if runtime.GOOS == "ios" { if runtime.GOOS == "ios" {
if !fileExists(mobileDep.StateFilePath) { if !fileExists(mobileDep.StateFilePath) {
@@ -303,8 +302,6 @@ func (e *Engine) Stop() error {
return fmt.Errorf("failed to remove all peers: %s", err) return fmt.Errorf("failed to remove all peers: %s", err)
} }
e.flowManager.Close()
if e.cancel != nil { if e.cancel != nil {
e.cancel() e.cancel()
} }
@@ -314,6 +311,12 @@ func (e *Engine) Stop() error {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
e.close() e.close()
// stop flow manager after wg interface is gone
if e.flowManager != nil {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
@@ -348,6 +351,10 @@ func (e *Engine) Start() error {
} }
e.wgInterface = wgIface e.wgInterface = wgIface
// start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive { if e.config.RosenpassPermissive {
@@ -454,7 +461,7 @@ func (e *Engine) createFirewall() error {
} }
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
if err != nil || e.firewall == nil { if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
return nil return nil
@@ -488,13 +495,13 @@ func (e *Engine) initFirewall() error {
// this rule is static and will be torn down on engine down by the firewall manager // this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering( if _, err := e.firewall.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
firewallManager.ProtocolUDP, firewallManager.ProtocolUDP,
nil, nil,
&port, &port,
firewallManager.ActionAccept, firewallManager.ActionAccept,
"", "",
"",
); err != nil { ); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err) log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil return nil
@@ -518,6 +525,7 @@ func (e *Engine) blockLanAccess() {
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
for _, network := range toBlock { for _, network := range toBlock {
if _, err := e.firewall.AddRouteFiltering( if _, err := e.firewall.AddRouteFiltering(
nil,
[]netip.Prefix{v4}, []netip.Prefix{v4},
network, network,
firewallManager.ProtocolALL, firewallManager.ProtocolALL,
@@ -721,12 +729,13 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
return e.flowManager.Update(flowConfig) return e.flowManager.Update(flowConfig)
} }
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*types.FlowConfig, error) { func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
if config.GetInterval() == nil { if config.GetInterval() == nil {
return nil, errors.New("flow interval is nil") return nil, errors.New("flow interval is nil")
} }
return &types.FlowConfig{ return &nftypes.FlowConfig{
Enabled: config.GetEnabled(), Enabled: config.GetEnabled(),
Counters: config.GetCounters(),
URL: config.GetUrl(), URL: config.GetUrl(),
TokenPayload: config.GetTokenPayload(), TokenPayload: config.GetTokenPayload(),
TokenSignature: config.GetTokenSignature(), TokenSignature: config.GetTokenSignature(),
@@ -1419,7 +1428,7 @@ func (e *Engine) close() {
} }
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Reset(e.stateManager) err := e.firewall.Close(e.stateManager)
if err != nil { if err != nil {
log.Warnf("failed to reset firewall: %s", err) log.Warnf("failed to reset firewall: %s", err)
} }
@@ -1632,16 +1641,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns) return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
} }
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
log.Info("restarting engine") e.syncMsgMux.Lock()
CtxGetState(e.ctx).Set(StatusConnecting) defer e.syncMsgMux.Unlock()
if err := e.Stop(); err != nil { if e.ctx.Err() != nil {
log.Errorf("Failed to stop engine: %v", err) return
} }
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated") log.Infof("cancelling client context, engine will be recreated")
e.clientCancel() e.clientCancel()
} }
@@ -1653,34 +1665,17 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New() e.networkMonitor = networkmonitor.New()
go func() { go func() {
var mu sync.Mutex if err := e.networkMonitor.Listen(e.ctx); err != nil {
var debounceTimer *time.Timer if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
// Start the network monitor with a callback, Start will block until the monitor is stopped, return
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
} }
log.Errorf("network monitor error: %v", err)
// Set a new timer to debounce rapid network changes return
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
} }
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
}() }()
} }

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -75,7 +76,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() device.WGAddress AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
@@ -114,7 +115,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc() return m.NameFunc()
} }
func (m *MockWGIface) Address() device.WGAddress { func (m *MockWGIface) Address() wgaddr.Address {
return m.AddressFunc() return m.AddressFunc()
} }
@@ -364,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error

View File

@@ -0,0 +1,306 @@
//go:build linux && !android
package conntrack
import (
"encoding/binary"
"fmt"
"net/netip"
"sync"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
nfct "github.com/ti-mo/conntrack"
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const defaultChannelSize = 100
// ConnTrack manages kernel-based conntrack events
type ConnTrack struct {
flowLogger nftypes.FlowLogger
iface nftypes.IFaceMapper
conn *nfct.Conn
mux sync.Mutex
instanceID uuid.UUID
started bool
done chan struct{}
sysctlModified bool
}
// New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
return &ConnTrack{
flowLogger: flowLogger,
iface: iface,
instanceID: uuid.New(),
started: false,
done: make(chan struct{}, 1),
}
}
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
func (c *ConnTrack) Start(enableCounters bool) error {
c.mux.Lock()
defer c.mux.Unlock()
if c.started {
return nil
}
log.Info("Starting conntrack event listening")
if enableCounters {
c.EnableAccounting()
}
conn, err := nfct.Dial(nil)
if err != nil {
return fmt.Errorf("dial conntrack: %w", err)
}
c.conn = conn
events := make(chan nfct.Event, defaultChannelSize)
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
netfilter.GroupCTNew,
netfilter.GroupCTDestroy,
})
if err != nil {
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
return fmt.Errorf("start conntrack listener: %w", err)
}
c.started = true
go c.receiverRoutine(events, errChan)
return nil
}
func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) {
for {
select {
case event := <-events:
c.handleEvent(event)
case err := <-errChan:
log.Errorf("Error from conntrack event listener: %v", err)
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
return
case <-c.done:
return
}
}
}
// Stop stops the connection tracking. This method is idempotent.
func (c *ConnTrack) Stop() {
c.mux.Lock()
defer c.mux.Unlock()
if !c.started {
return
}
log.Info("Stopping conntrack event listening")
select {
case c.done <- struct{}{}:
default:
}
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
}
c.started = false
c.RestoreAccounting()
}
// Close stops listening for events and cleans up resources
func (c *ConnTrack) Close() error {
c.mux.Lock()
defer c.mux.Unlock()
if c.started {
select {
case c.done <- struct{}{}:
default:
}
}
if c.conn != nil {
err := c.conn.Close()
c.conn = nil
c.started = false
c.RestoreAccounting()
if err != nil {
return fmt.Errorf("close conntrack: %w", err)
}
}
return nil
}
// handleEvent processes incoming conntrack events
func (c *ConnTrack) handleEvent(event nfct.Event) {
if event.Flow == nil {
return
}
if event.Type != nfct.EventNew && event.Type != nfct.EventDestroy {
return
}
flow := *event.Flow
proto := nftypes.Protocol(flow.TupleOrig.Proto.Protocol)
if proto == nftypes.ProtocolUnknown {
return
}
srcIP := flow.TupleOrig.IP.SourceAddress
dstIP := flow.TupleOrig.IP.DestinationAddress
if !c.relevantFlow(srcIP, dstIP) {
return
}
var srcPort, dstPort uint16
var icmpType, icmpCode uint8
switch proto {
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
srcPort = flow.TupleOrig.Proto.SourcePort
dstPort = flow.TupleOrig.Proto.DestinationPort
case nftypes.ICMP:
icmpType = flow.TupleOrig.Proto.ICMPType
icmpCode = flow.TupleOrig.Proto.ICMPCode
}
flowID := c.getFlowID(flow.ID)
direction := c.inferDirection(srcIP, dstIP)
eventType := nftypes.TypeStart
eventStr := "New"
if event.Type == nfct.EventDestroy {
eventType = nftypes.TypeEnd
eventStr = "Ended"
}
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
c.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: eventType,
Direction: direction,
Protocol: proto,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
ICMPType: icmpType,
ICMPCode: icmpCode,
RxPackets: c.mapRxPackets(flow, direction),
TxPackets: c.mapTxPackets(flow, direction),
RxBytes: c.mapRxBytes(flow, direction),
TxBytes: c.mapTxBytes(flow, direction),
})
}
// relevantFlow checks if the flow is related to the specified interface
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
// TODO: filter traffic by interface
wgnet := c.iface.Address().Network
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
return false
}
return true
}
// mapRxPackets maps packet counts to RX based on flow direction
func (c *ConnTrack) mapRxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersOrig is from external to us (RX)
// For Egress: CountersReply is from external to us (RX)
if direction == nftypes.Ingress {
return flow.CountersOrig.Packets
}
return flow.CountersReply.Packets
}
// mapTxPackets maps packet counts to TX based on flow direction
func (c *ConnTrack) mapTxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersReply is from us to external (TX)
// For Egress: CountersOrig is from us to external (TX)
if direction == nftypes.Ingress {
return flow.CountersReply.Packets
}
return flow.CountersOrig.Packets
}
// mapRxBytes maps byte counts to RX based on flow direction
func (c *ConnTrack) mapRxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersOrig is from external to us (RX)
// For Egress: CountersReply is from external to us (RX)
if direction == nftypes.Ingress {
return flow.CountersOrig.Bytes
}
return flow.CountersReply.Bytes
}
// mapTxBytes maps byte counts to TX based on flow direction
func (c *ConnTrack) mapTxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersReply is from us to external (TX)
// For Egress: CountersOrig is from us to external (TX)
if direction == nftypes.Ingress {
return flow.CountersReply.Bytes
}
return flow.CountersOrig.Bytes
}
// getFlowID creates a unique UUID based on the conntrack ID and instance ID
func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], conntrackID)
return uuid.NewSHA1(c.instanceID, buf[:])
}
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch {
case wgaddr.Equal(src):
return nftypes.Egress
case wgaddr.Equal(dst):
return nftypes.Ingress
case wgnetwork.Contains(src):
// netbird network -> resource network
return nftypes.Ingress
case wgnetwork.Contains(dst):
// resource network -> netbird network
return nftypes.Egress
// TODO: handle site2site traffic
}
return nftypes.DirectionUnknown
}

View File

@@ -0,0 +1,9 @@
//go:build !linux || android
package conntrack
import nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) nftypes.ConnTracker {
return nil
}

View File

@@ -0,0 +1,73 @@
//go:build linux && !android
package conntrack
import (
"fmt"
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// conntrackAcctPath is the sysctl path for conntrack accounting
conntrackAcctPath = "net.netfilter.nf_conntrack_acct"
)
// EnableAccounting ensures that connection tracking accounting is enabled in the kernel.
func (c *ConnTrack) EnableAccounting() {
// haven't restored yet
if c.sysctlModified {
return
}
modified, err := setSysctl(conntrackAcctPath, 1)
if err != nil {
log.Warnf("Failed to enable conntrack accounting: %v", err)
return
}
c.sysctlModified = modified
}
// RestoreAccounting restores the connection tracking accounting setting to its original value.
func (c *ConnTrack) RestoreAccounting() {
if !c.sysctlModified {
return
}
if _, err := setSysctl(conntrackAcctPath, 0); err != nil {
log.Warnf("Failed to restore conntrack accounting: %v", err)
return
}
c.sysctlModified = false
}
// setSysctl sets a sysctl configuration and returns whether it was modified.
func setSysctl(key string, desiredValue int) (bool, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return false, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return false, fmt.Errorf("convert current value to int: %w", err)
}
if currentV == desiredValue {
return false, nil
}
// nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return false, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return true, nil
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
) )
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
@@ -21,15 +22,17 @@ type Logger struct {
enabled atomic.Bool enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc cancelReceiver context.CancelFunc
statusRecorder *peer.Status
Store types.Store Store types.Store
} }
func New(ctx context.Context) *Logger { func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &Logger{ return &Logger{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
Store: store.NewMemoryStore(), statusRecorder: statusRecorder,
Store: store.NewMemoryStore(),
} }
} }
@@ -58,13 +61,14 @@ func (l *Logger) startReceiver() {
if l.enabled.Load() { if l.enabled.Load() {
return return
} }
l.mux.Lock() l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx) ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel l.cancelReceiver = cancel
l.mux.Unlock() l.mux.Unlock()
c := make(rcvChan, 100) c := make(rcvChan, 100)
l.rcvChan.Swap(&c) l.rcvChan.Store(&c)
l.enabled.Store(true) l.enabled.Store(true)
for { for {
@@ -73,12 +77,15 @@ func (l *Logger) startReceiver() {
log.Info("flow Memory store receiver stopped") log.Info("flow Memory store receiver stopped")
return return
case eventFields := <-c: case eventFields := <-c:
id := uuid.NewString() id := uuid.New()
event := types.Event{ event := types.Event{
ID: id, ID: id,
EventFields: *eventFields, EventFields: *eventFields,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
event.SourceResourceID = []byte(srcResId)
event.DestResourceID = []byte(dstResId)
l.Store.StoreEvent(&event) l.Store.StoreEvent(&event)
} }
} }
@@ -100,6 +107,7 @@ func (l *Logger) stop() {
l.cancelReceiver() l.cancelReceiver()
l.cancelReceiver = nil l.cancelReceiver = nil
} }
l.rcvChan.Store(nil)
l.mux.Unlock() l.mux.Unlock()
} }
@@ -107,6 +115,10 @@ func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents() return l.Store.GetEvents()
} }
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
l.Store.DeleteEvents(ids)
}
func (l *Logger) Close() { func (l *Logger) Close() {
l.stop() l.stop()
l.cancel() l.cancel()

View File

@@ -12,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(context.Background()) logger := logger.New(context.Background(), nil)
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{

View File

@@ -2,47 +2,234 @@ package netflow
import ( import (
"context" "context"
"errors"
"fmt"
"runtime"
"sync" "sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger" "github.com/netbirdio/netbird/client/internal/netflow/logger"
"github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
) )
// Manager handles netflow tracking and logging
type Manager struct { type Manager struct {
mux sync.Mutex mux sync.Mutex
logger types.FlowLogger logger nftypes.FlowLogger
flowConfig *types.FlowConfig flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
ctx context.Context
receiverClient *client.GRPCClient
publicKey []byte
} }
func NewManager(ctx context.Context) *Manager { // NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
flowLogger := logger.New(ctx, statusRecorder)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
ct = conntrack.New(flowLogger, iface)
}
return &Manager{ return &Manager{
logger: logger.New(ctx), logger: flowLogger,
conntrack: ct,
ctx: ctx,
publicKey: publicKey,
} }
} }
func (m *Manager) Update(update *types.FlowConfig) error { // Update applies new flow configuration settings
m.mux.Lock() // needsNewClient checks if a new client needs to be created
defer m.mux.Unlock() func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
if update == nil { current := m.flowConfig
return nil return previous == nil ||
!previous.Enabled ||
previous.TokenPayload != current.TokenPayload ||
previous.TokenSignature != current.TokenSignature ||
previous.URL != current.URL
}
// enableFlow starts components for flow tracking
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %s", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
} }
m.flowConfig = update m.logger.Enable()
if update.Enabled { if m.conntrack != nil {
m.logger.Enable() if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
return nil return fmt.Errorf("start conntrack: %w", err)
}
} }
m.logger.Disable()
return nil return nil
} }
// disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error {
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
if m.receiverClient != nil {
return m.receiverClient.Close()
}
return nil
}
// Update applies new flow configuration settings
func (m *Manager) Update(update *nftypes.FlowConfig) error {
if update == nil {
return nil
}
m.mux.Lock()
defer m.mux.Unlock()
previous := m.flowConfig
m.flowConfig = update
if update.Enabled {
return m.enableFlow(previous)
}
return m.disableFlow()
}
// Close cleans up all resources
func (m *Manager) Close() { func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if m.conntrack != nil {
m.conntrack.Close()
}
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %s", err)
}
}
m.logger.Close() m.logger.Close()
} }
func (m *Manager) GetLogger() types.FlowLogger { // GetLogger returns the flow logger
func (m *Manager) GetLogger() nftypes.FlowLogger {
return m.logger return m.logger
} }
func (m *Manager) startSender() {
ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
events := m.logger.GetEvents()
for _, event := range events {
if err := m.send(event); err != nil {
log.Errorf("failed to send flow event to server: %s", err)
continue
}
log.Tracef("sent flow event: %s", event.ID)
}
}
}
}
func (m *Manager) receiveACKs(client *client.GRPCClient) {
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
log.Tracef("received flow event ack: %s", ack.EventId)
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("failed to receive flow event ack: %s", err)
}
}
func (m *Manager) send(event *nftypes.Event) error {
m.mux.Lock()
client := m.receiverClient
m.mux.Unlock()
if client == nil {
return nil
}
return client.Send(toProtoEvent(m.publicKey, event))
}
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
protoEvent := &proto.FlowEvent{
EventId: event.ID[:],
Timestamp: timestamppb.New(event.Timestamp),
PublicKey: publicKey,
FlowFields: &proto.FlowFields{
FlowId: event.FlowID[:],
RuleId: event.RuleID,
Type: proto.Type(event.Type),
Direction: proto.Direction(event.Direction),
Protocol: uint32(event.Protocol),
SourceIp: event.SourceIP.AsSlice(),
DestIp: event.DestIP.AsSlice(),
RxPackets: event.RxPackets,
TxPackets: event.TxPackets,
RxBytes: event.RxBytes,
TxBytes: event.TxBytes,
SourceResourceId: event.SourceResourceID,
DestResourceId: event.DestResourceID,
},
}
if event.Protocol == nftypes.ICMP {
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
IcmpInfo: &proto.ICMPInfo{
IcmpType: uint32(event.ICMPType),
IcmpCode: uint32(event.ICMPCode),
},
}
return protoEvent
}
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_PortInfo{
PortInfo: &proto.PortInfo{
SourcePort: uint32(event.SourcePort),
DestPort: uint32(event.DestPort),
},
}
return protoEvent
}

View File

@@ -3,18 +3,22 @@ package store
import ( import (
"sync" "sync"
"golang.org/x/exp/maps"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
func NewMemoryStore() *Memory { func NewMemoryStore() *Memory {
return &Memory{ return &Memory{
events: make(map[string]*types.Event), events: make(map[uuid.UUID]*types.Event),
} }
} }
type Memory struct { type Memory struct {
mux sync.Mutex mux sync.Mutex
events map[string]*types.Event events map[uuid.UUID]*types.Event
} }
func (m *Memory) StoreEvent(event *types.Event) { func (m *Memory) StoreEvent(event *types.Event) {
@@ -26,7 +30,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() { func (m *Memory) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
m.events = make(map[string]*types.Event) maps.Clear(m.events)
} }
func (m *Memory) GetEvents() []*types.Event { func (m *Memory) GetEvents() []*types.Event {
@@ -39,7 +43,7 @@ func (m *Memory) GetEvents() []*types.Event {
return events return events
} }
func (m *Memory) DeleteEvents(ids []string) { func (m *Memory) DeleteEvents(ids []uuid.UUID) {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
for _, id := range ids { for _, id := range ids {

View File

@@ -2,48 +2,98 @@ package types
import ( import (
"net/netip" "net/netip"
"strconv"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type Protocol uint8
const (
ProtocolUnknown = Protocol(0)
ICMP = Protocol(1)
TCP = Protocol(6)
UDP = Protocol(17)
SCTP = Protocol(132)
)
func (p Protocol) String() string {
switch p {
case 1:
return "ICMP"
case 6:
return "TCP"
case 17:
return "UDP"
case 132:
return "SCTP"
default:
return strconv.FormatUint(uint64(p), 10)
}
}
type Type int type Type int
const ( const (
TypeStart = iota TypeUnknown = Type(iota)
TypeStart
TypeEnd TypeEnd
TypeDrop
) )
type Direction int type Direction int
func (d Direction) String() string {
switch d {
case Ingress:
return "ingress"
case Egress:
return "egress"
default:
return "unknown"
}
}
const ( const (
Ingress = iota DirectionUnknown = Direction(iota)
Ingress
Egress Egress
) )
type Event struct { type Event struct {
ID string ID uuid.UUID
Timestamp time.Time Timestamp time.Time
EventFields EventFields
} }
type EventFields struct { type EventFields struct {
FlowID uuid.UUID FlowID uuid.UUID
Type Type Type Type
Direction Direction RuleID []byte
Protocol uint8 Direction Direction
SourceIP netip.Addr Protocol Protocol
DestIP netip.Addr SourceIP netip.Addr
SourcePort uint16 DestIP netip.Addr
DestPort uint16 SourceResourceID []byte
ICMPType uint8 DestResourceID []byte
ICMPCode uint8 SourcePort uint16
DestPort uint16
ICMPType uint8
ICMPCode uint8
RxPackets uint64
TxPackets uint64
RxBytes uint64
TxBytes uint64
} }
type FlowConfig struct { type FlowConfig struct {
URL string URL string
Interval time.Duration Interval time.Duration
Enabled bool Enabled bool
Counters bool
TokenPayload string TokenPayload string
TokenSignature string TokenSignature string
} }
@@ -62,6 +112,8 @@ type FlowLogger interface {
StoreEvent(flowEvent EventFields) StoreEvent(flowEvent EventFields)
// GetEvents returns all stored events // GetEvents returns all stored events
GetEvents() []*Event GetEvents() []*Event
// DeleteEvents deletes events from the store
DeleteEvents([]uuid.UUID)
// Close closes the logger // Close closes the logger
Close() Close()
// Enable enables the flow logger receiver // Enable enables the flow logger receiver
@@ -76,7 +128,24 @@ type Store interface {
// GetEvents returns all stored events // GetEvents returns all stored events
GetEvents() []*Event GetEvents() []*Event
// DeleteEvents deletes events from the store // DeleteEvents deletes events from the store
DeleteEvents([]string) DeleteEvents([]uuid.UUID)
// Close closes the store // Close closes the store
Close() Close()
} }
// ConnTracker defines the interface for connection tracking functionality
type ConnTracker interface {
// Start begins tracking connections by listening for conntrack events.
Start(bool) error
// Stop stops the connection tracking.
Stop()
// Close stops listening for events and cleans up resources
Close() error
}
// IFaceMapper provides interface to check if we're using userspace WireGuard
type IFaceMapper interface {
IsUserspaceBind() bool
Name() string
Address() wgaddr.Address
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
@@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
} }
}() }()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
default: default:
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
@@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
switch msg.Type { switch msg.Type {
case unix.RTM_ADD: case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback() return nil
case unix.RTM_DELETE: case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback() return nil
} }
} }
} }

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil { if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available") return errors.New("no interfaces available")
} }
@@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
// handle route changes // handle route changes
case route := <-routeChan: case route := <-routeChan:
// default route and main table // default route and main table
@@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
// triggered on added/replaced routes // triggered on added/replaced routes
case syscall.RTM_NEWROUTE: case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil return nil
case syscall.RTM_DELROUTE: case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil return nil
} }
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx) routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err) return fmt.Errorf("failed to create route monitor: %w", err)
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
case route := <-routeMonitor.RouteUpdates(): case route := <-routeMonitor.RouteUpdates():
if route.Destination.Bits() != 0 { if route.Destination.Bits() != 0 {
continue continue
} }
if routeChanged(route, nexthopv4, nexthopv6, callback) { if routeChanged(route, nexthopv4, nexthopv6) {
break return nil
} }
} }
} }
} }
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
intf := "<nil>" intf := "<nil>"
if route.Interface != nil { if route.Interface != nil {
intf = route.Interface.Name intf = route.Interface.Name
@@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
case systemops.RouteModified: case systemops.RouteModified:
// TODO: get routing table to figure out if our route is affected for modified routes // TODO: get routing table to figure out if our route is affected for modified routes
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
case systemops.RouteAdded: case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
} }
case systemops.RouteDeleted: case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
} }
} }

View File

@@ -1,12 +1,27 @@
//go:build !ios && !android
package networkmonitor package networkmonitor
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/netip"
"runtime/debug"
"sync" "sync"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
var ErrStopped = errors.New("monitor has been stopped") const (
debounceTime = 2 * time.Second
)
var checkChangeFn = checkChange
// NetworkMonitor watches for changes in network configuration. // NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct { type NetworkMonitor struct {
@@ -19,3 +34,99 @@ type NetworkMonitor struct {
func New() *NetworkMonitor { func New() *NetworkMonitor {
return &NetworkMonitor{} return &NetworkMonitor{}
} }
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()
return errors.New("network monitor already started")
}
ctx, nw.cancel = context.WithCancel(ctx)
defer nw.cancel()
nw.wg.Add(1)
nw.mu.Unlock()
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
// debounce changes
timer := time.NewTimer(0)
timer.Stop()
for {
select {
case <-event:
timer.Reset(debounceTime)
case <-timer.C:
return nil
case <-ctx.Done():
timer.Stop()
return ctx.Err()
}
}
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel == nil {
return
}
nw.cancel()
nw.wg.Wait()
}
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event)
return
}
// prevent blocking
select {
case event <- struct{}{}:
default:
}
}
}

View File

@@ -1,82 +0,0 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

Some files were not shown because too many files have changed in this diff Show More