Compare commits

...

60 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
Pascal Fischer
ab2e3fec72 expose resource type consts 2025-03-12 13:49:24 +01:00
Hakan Sariman
47f88f7057 Refactor routeIDLookup methods to use Addr() for resolved IP operations 2025-03-11 19:43:58 +08:00
Hakan Sariman
ee33a6ed7c Refactor RemoveLocalPeerStateRoute to eliminate resourceId parameter 2025-03-11 13:19:30 +08:00
Hakan Sariman
da662cfd08 Add source and destination resource IDs to FlowFields 2025-03-11 13:12:54 +08:00
Hakan Sariman
ed2ee1ee9d Merge branch 'feature/flow' into feat/flow-resid 2025-03-11 13:08:11 +08:00
Viktor Liu
76d73548d6 Fix more conflicts 2025-03-10 18:46:01 +01:00
Viktor Liu
11828a064a Fix conflict 2025-03-10 18:35:32 +01:00
Viktor Liu
0c2a3dd937 Merge branch 'main' into feature/flow 2025-03-10 18:30:45 +01:00
Zoltan Papp
cd9eff5331 Increase the timeout to 50 sec (#3481) 2025-03-10 18:23:47 +01:00
Viktor Liu
47dcf8d68c Fix forwarder IP source/destination (#3463) 2025-03-10 14:55:07 +01:00
Viktor Liu
80ceb80197 [client] Ignore candidates that are part of the the wireguard subnet (#3472) 2025-03-10 13:59:21 +01:00
Bethuel Mmbaga
cc8f6bcaf3 [management] Fix tests circular dependency (#3460)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-10 15:54:36 +03:00
Zoltan Papp
636a0e2475 [client] Fix engine restart (#3435)
- Refactor the network monitoring to handle one event and it after return
- In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer
- Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events.
2025-03-10 13:32:12 +01:00
Viktor Liu
e66e329bf6 [client] Add option to autostart netbird ui in the Windows installer (#3469) 2025-03-10 13:19:17 +01:00
Zoltan Papp
aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00
Zoltan Papp
6bef474e9e [client] Prevent panic in case of double close call (#3475)
Prevent panic in case of double close call
2025-03-10 13:16:28 +01:00
Maycon Santos
81040ff80a [docs] Update typo (#3477) 2025-03-10 11:52:36 +01:00
Viktor Liu
c73481aee4 [client] Enable windows stderr logs by default (#3476) 2025-03-10 11:30:49 +01:00
Hakan Sariman
92286b2541 Implement routeIDLookup for managing local and remote route IDs 2025-03-10 15:58:45 +08:00
Maycon Santos
d8bcf745b0 update integrations 2025-03-09 19:32:38 +01:00
Maycon Santos
8430139d80 fix missing method 2025-03-09 19:03:57 +01:00
Maycon Santos
a2962b4ce0 sync go.sum 2025-03-09 18:50:20 +01:00
Maycon Santos
16fffdb75b sync changes from #3426 2025-03-09 18:48:48 +01:00
Maycon Santos
036cecbf46 update integrations and go mod 2025-03-09 18:47:05 +01:00
Maycon Santos
3482852bb6 sync proto and sum 2025-03-09 18:02:33 +01:00
Maycon Santos
fd62665b1f Merge branch 'main' into feature/flow
# Conflicts:
#	client/cmd/testutil_test.go
#	client/firewall/iptables/router_linux.go
#	client/firewall/nftables/router_linux.go
#	client/firewall/uspfilter/allow_netbird.go
#	client/firewall/uspfilter/allow_netbird_windows.go
#	client/firewall/uspfilter/uspfilter_test.go
#	client/internal/engine.go
#	client/internal/engine_test.go
#	client/server/server_test.go
#	go.mod
#	go.sum
#	management/client/client_test.go
#	management/cmd/management.go
#	management/proto/management.pb.go
#	management/proto/management.proto
#	management/server/account.go
#	management/server/account_test.go
#	management/server/dns_test.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/tools.go
#	management/server/integrations/port_forwarding/controller.go
#	management/server/management_proto_test.go
#	management/server/management_test.go
#	management/server/nameserver_test.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/route_test.go
2025-03-09 17:42:16 +01:00
Viktor Liu
fc1da94520 [client, management] Add port forwarding (#3275)
Add initial support to ingress ports on the client code.

- new types where added
- new protocol messages and controller
2025-03-09 16:06:43 +01:00
Hakan Sariman
1ffe48f0d4 Add nil check in CheckRoutes to prevent potential panic 2025-03-08 12:54:33 +03:00
Hakan Sariman
a3b8a21385 Refactor CheckRoutes to return resource IDs for matching source and destination addresses 2025-03-08 12:26:53 +03:00
Hakan Sariman
86492b88c4 Refactor route handling to simplify route information and improve state management 2025-03-08 12:25:35 +03:00
Hakan Sariman
d08a629f9e Merge branch 'feature/flow' into feat/flow-resid 2025-03-08 12:18:02 +03:00
Viktor Liu
36da464413 Fix tracer test 2025-03-07 17:19:10 +01:00
Hakan Sariman
268e3404d3 Merge branch 'feature/flow' into feat/flow-resid 2025-03-07 18:52:11 +03:00
Hakan Sariman
54d0591833 Refactor route handling to use RouteWithResourceId for improved state management 2025-03-07 18:43:49 +03:00
Muzammil
ae6b61301c Muz/netbird dashboards (#3458)
* added all 3 dashboards

* update readme
2025-03-07 16:13:11 +01:00
Viktor Liu
86370a0e7b Use bytes for flows event id (#3439) 2025-03-07 16:12:47 +01:00
Philippe Vaucher
a444e551b3 [misc] Traefik config improvements (#3346)
* Remove deprecated docker-compose version

* Prettify docker-compose files

* Backports missing logging entries

* Fix signal port

* Add missing relay configuration

* Serve management over 33073 to avoid confusion
2025-03-07 16:10:11 +01:00
Zoltan Papp
53b9a2002f Print out the goroutine id (#3433)
The TXT logger prints out the actual go routine ID

This feature depends on 'loggoroutine' build tag

```go build -tags loggoroutine```
2025-03-07 14:06:47 +01:00
Viktor Liu
cb16d0f45f Align packet tracer behavior with actual code paths (#3424) 2025-03-07 14:03:45 +01:00
Viktor Liu
e8d8bd8f18 Add peer traffic rule IDs to allowed connections in flows (#3442) 2025-03-07 13:56:26 +01:00
Viktor Liu
8b07f21c28 Don't track intercepted packets (#3448) 2025-03-07 13:56:16 +01:00
Viktor Liu
54be772ffd Handle flow updates (#3455) 2025-03-07 13:56:00 +01:00
Zoltan Papp
4b76d93cec [client] Fix TURN-Relay switch (#3456)
- When a peer is connected with TURN and a Relay connection is established, do not force switching to Relay. Keep using TURN until disconnection.

-In the proxy preparation phase, the Bind Proxy does not set the remote conn as a fake address for Bind. When running the Work() function, the proper proxy instance updates the conn inside the Bind.
2025-03-07 12:00:25 +01:00
Viktor Liu
3c3a454e61 Fix merge regression 2025-03-06 16:54:15 +01:00
Viktor Liu
5ff77b3595 Add flow userspace counters (#3438) 2025-03-06 16:52:56 +01:00
Viktor Liu
b180edbe5c Track icmp with id only (#3447) 2025-03-06 14:51:23 +01:00
Hakan Sariman
de3b5c78d7 Fix nil pointer dereference in CheckRoutes method 2025-03-06 14:10:31 +03:00
Hakan Sariman
0b42f40cf6 Refactor route management to include resource IDs in state handling 2025-03-06 13:51:46 +03:00
Viktor Liu
062d1ec76f [misc] Update bug-issue-report.md template (#3449) 2025-03-06 01:10:37 +01:00
Viktor Liu
0a042ac36d Fix merge conflict 2025-03-05 19:11:20 +01:00
Viktor Liu
c111675dd8 [client] Handle large DNS packets in dns route resolution (#3441) 2025-03-05 18:57:17 +01:00
Hakan Sariman
e7f921d787 [client] add resource id fields to netflow events 2025-03-05 20:35:52 +03:00
Viktor Liu
e9f11fb11b Replace net.IP with netip.Addr (#3425) 2025-03-05 18:28:05 +01:00
hakansa
419ed275fa Handle TCP RST flag to transition connection state to closed (#3432) 2025-03-05 18:25:42 +01:00
hakansa
60ffe0dc87 [client] UI Refactor Icon Paths (#3420)
[client] UI Refactor Icon Paths (#3420)
2025-03-04 18:29:29 +03:00
Viktor Liu
bcc5824980 [client] Close userspace firewall properly (#3426) 2025-03-04 11:19:42 +01:00
robertgro
af5796de1c [client] Add Netbird GitHub link to the client ui about sub menu (#3372) 2025-03-03 17:32:50 +01:00
Philippe Vaucher
9d604b7e66 [client Fix env var typo (#3415) 2025-03-03 17:22:51 +01:00
Bethuel Mmbaga
82c12cc8ae [management] Handle transaction error on peer deletion (#3387)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 19:57:04 +00:00
211 changed files with 9453 additions and 1768 deletions

View File

@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
`netbird version` `netbird version`
**NetBird status -dA output:** **Is any other VPN software installed?**
If applicable, add the `netbird status -dA' command output. If yes, which one?
**Do you face any (non-mobile) client issues?** **Debug output**
Please provide the file created by `netbird debug for 1m -AS`. To help us resolve the problem, please attach the following debug output
We advise reviewing the anonymized files for any remaining PII.
netbird status -dA
As well as the file created by
netbird debug for 1m -AS
We advise reviewing the anonymized output for any remaining personal information.
**Screenshots** **Screenshots**
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
**Additional context** **Additional context**
Add any other context about the problem here. Add any other context about the problem here.
**Have you tried these troubleshooting steps?**
- [ ] Checked for newer NetBird versions
- [ ] Searched for similar issues on GitHub (including closed ones)
- [ ] Restarted the NetBird client
- [ ] Disabled other VPN software
- [ ] Checked firewall settings

View File

@@ -258,7 +258,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -325,8 +325,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -392,7 +392,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -461,7 +461,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:

View File

@@ -71,7 +71,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4
with: with:
@@ -150,7 +150,7 @@ jobs:
- name: Install goversioninfo - name: Install goversioninfo
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
- name: Generate windows syso amd64 - name: Generate windows syso amd64
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@v4 uses: goreleaser/goreleaser-action@v4

View File

@@ -53,9 +53,9 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird.png - src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@@ -72,9 +72,9 @@ nfpms:
scripts: scripts:
postinstall: "release_files/ui-post-install.sh" postinstall: "release_files/ui-post-install.sh"
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/build/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird.png - src: client/ui/assets/netbird.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan error, 1) run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
run <- err clientErr <- err
} }
}() }()
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
} }
return startCtx.Err() return startCtx.Err()
case err := <-run: case err := <-clientErr:
if err != nil { return fmt.Errorf("startup: %w", err)
if stopErr := client.Stop(); stopErr != nil { case <-run:
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
} }
c.connect = client c.connect = client

View File

@@ -4,12 +4,13 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -167,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -62,7 +62,7 @@ func TestIptablesManager(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -100,14 +100,14 @@ func TestIptablesManager(t *testing.T) {
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "") _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules) ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
require.NoError(t, err, "failed check chain exists") require.NoError(t, err, "failed check chain exists")
if ok { if ok {
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules) require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
} }
}) })
} }
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -136,7 +136,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -166,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}) })
t.Run("reset check", func(t *testing.T) { t.Run("reset check", func(t *testing.T) {
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
}) })
} }
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -204,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "clear the manager state") require.NoError(t, err, "clear the manager state")
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -4,21 +4,20 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
ipt.aclMgr.ipsetStore = s.ACLIPsetStore ipt.aclMgr.ipsetStore = s.ACLIPsetStore
} }
if err := ipt.Reset(nil); err != nil { if err := ipt.Close(nil); err != nil {
return fmt.Errorf("reset iptables manager: %w", err) return fmt.Errorf("reset iptables manager: %w", err)
} }

View File

@@ -102,8 +102,8 @@ type Manager interface {
// SetLegacyManagement sets the legacy management mode // SetLegacyManagement sets the legacy management mode
SetLegacyManagement(legacy bool) error SetLegacyManagement(legacy bool) error
// Reset firewall to the default state // Close closes the firewall manager
Reset(stateManager *statemanager.Manager) error Close(stateManager *statemanager.Manager) error
// Flush the changes to firewall controller // Flush the changes to firewall controller
Flush() error Flush() error

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// We only need to record minimal interface state for potential recreation. // We only need to record minimal interface state for potential recreation.
// Unlike iptables, which requires tracking individual rules, nftables maintains // Unlike iptables, which requires tracking individual rules, nftables maintains
// a known state (our netbird table plus a few static rules). This allows for easy // a known state (our netbird table plus a few static rules). This allows for easy
// cleanup using Reset() without needing to store specific rules. // cleanup using Close() without needing to store specific rules.
if err := stateManager.UpdateState(&ShutdownState{ if err := stateManager.UpdateState(&ShutdownState{
InterfaceState: &InterfaceState{ InterfaceState: &InterfaceState{
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
@@ -243,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
} }
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -65,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
time.Sleep(time.Second) time.Sleep(time.Second)
}() }()
@@ -162,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
// established rule remains // established rule remains
require.Len(t, rules, 1, "expected 1 rules after deletion") require.Len(t, rules, 1, "expected 1 rules after deletion")
err = manager.Reset(nil) err = manager.Close(nil)
require.NoError(t, err, "failed to reset") require.NoError(t, err, "failed to reset")
} }
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -191,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -274,7 +274,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
t.Cleanup(func() { t.Cleanup(func() {
err := manager.Reset(nil) err := manager.Close(nil)
require.NoError(t, err, "failed to reset manager state") require.NoError(t, err, "failed to reset manager state")
// Verify iptables output after reset // Verify iptables output after reset

View File

@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present // need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))

View File

@@ -3,21 +3,20 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
return fmt.Errorf("create nftables manager: %w", err) return fmt.Errorf("create nftables manager: %w", err)
} }
if err := nft.Reset(nil); err != nil { if err := nft.Close(nil); err != nil {
return fmt.Errorf("reset nftables manager: %w", err) return fmt.Errorf("reset nftables manager: %w", err)
} }

View File

@@ -4,39 +4,36 @@ package uspfilter
import ( import (
"context" "context"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(stateManager *statemanager.Manager) error { func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {
@@ -48,7 +45,7 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
} }
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager) return m.nativeFirewall.Close(stateManager)
} }
return nil return nil
} }

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time" "time"
@@ -22,12 +23,12 @@ const (
) )
// Reset firewall to the default state // Reset firewall to the default state
func (m *Manager) Reset(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
@@ -44,8 +45,8 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@@ -3,14 +3,14 @@ package common
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() iface.WGAddress Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
} }

View File

@@ -2,7 +2,6 @@ package conntrack
import ( import (
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -14,13 +13,15 @@ import (
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
FlowId uuid.UUID FlowId uuid.UUID
Direction nftypes.Direction Direction nftypes.Direction
SourceIP netip.Addr SourceIP netip.Addr
DestIP netip.Addr DestIP netip.Addr
SourcePort uint16 lastSeen atomic.Int64
DestPort uint16 PacketsTx atomic.Uint64
lastSeen atomic.Int64 PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@@ -30,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())
@@ -52,16 +64,3 @@ type ConnKey struct {
func (c ConnKey) String() string { func (c ConnKey) String() string {
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort) return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
} }
// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}

View File

@@ -2,7 +2,7 @@ package conntrack
import ( import (
"context" "context"
"net" "net/netip"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@@ -12,7 +12,7 @@ import (
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
@@ -21,22 +21,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
} }
} }
}) })
@@ -46,22 +46,22 @@ func BenchmarkMemoryPressure(b *testing.B) {
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
} }
} }
}) })

View File

@@ -1,8 +1,8 @@
package conntrack package conntrack
import ( import (
"context"
"fmt" "fmt"
"net"
"net/netip" "net/netip"
"sync" "sync"
"time" "time"
@@ -23,14 +23,13 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct { type ICMPConnKey struct {
SrcIP netip.Addr SrcIP netip.Addr
DstIP netip.Addr DstIP netip.Addr
Sequence uint16 ID uint16
ID uint16
} }
func (i ICMPConnKey) String() string { func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (%d/%d)", i.SrcIP, i.DstIP, i.ID, i.Sequence) return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
@@ -46,8 +45,8 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{}
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@@ -57,21 +56,27 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger, logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
flowLogger: flowLogger, flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) (ICMPConnKey, bool) { func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
key := makeICMPKey(srcIP, dstIP, id, seq) key := ICMPConnKey{
SrcIP: srcIP,
DstIP: dstIP,
ID: id,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -79,6 +84,7 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@@ -87,22 +93,21 @@ func (t *ICMPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, id uint16, seq
} }
// TrackOutbound records an outbound ICMP connection // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, seq); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Egress) t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
} }
} }
// TrackInbound records an inbound ICMP Echo Request // TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode) { func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, seq, typecode, nftypes.Ingress) t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
} }
// track is the common implementation for tracking both inbound and outbound ICMP connections // track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction) { func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
// TODO: icmp doesn't need to extend the timeout key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
key, exists := t.updateIfExists(srcIP, dstIP, id, seq)
if exists { if exists {
return return
} }
@@ -112,7 +117,7 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
// non echo requests don't need tracking // non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) { if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, key, typ, code) t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return return
} }
@@ -120,8 +125,8 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(), FlowId: uuid.New(),
Direction: direction, Direction: direction,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
}, },
ICMPType: typ, ICMPType: typ,
ICMPCode: code, ICMPCode: code,
@@ -133,16 +138,20 @@ func (t *ICMPTracker) track(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, t
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code) t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, key, conn) t.sendEvent(nftypes.TypeStart, conn, ruleId)
} }
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) { if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false return false
} }
key := makeICMPKey(dstIP, srcIP, id, seq) key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -153,16 +162,19 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true return true
} }
func (t *ICMPTracker) cleanupRoutine() { func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -176,56 +188,58 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %s (timeout)", &key) t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
t.sendEvent(nftypes.TypeEnd, key, conn) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
func (t *ICMPTracker) sendEvent(typ nftypes.Type, key ICMPConnKey, conn *ICMPConnTrack) { func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6 Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
SourceIP: key.SrcIP, SourceIP: conn.SourceIP,
DestIP: key.DstIP, DestIP: conn.DestIP,
ICMPType: conn.ICMPType, ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode, ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
}) })
} }
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, key ICMPConnKey, typ, code uint8) { func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
t.flowLogger.StoreEvent(nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeStart, Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction, Direction: direction,
Protocol: nftypes.ICMP, Protocol: nftypes.ICMP,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
ICMPType: typ, ICMPType: typ,
ICMPCode: code, ICMPCode: code,
})
}
// makeICMPKey creates an ICMP connection key
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
return ICMPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
ID: id,
Sequence: seq,
} }
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
} }

View File

@@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
) )
@@ -10,12 +10,12 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535), 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
@@ -23,17 +23,17 @@ func BenchmarkICMPTracker(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i), 0) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
} }
}) })
} }

View File

@@ -3,7 +3,8 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"net" "context"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -88,6 +89,8 @@ const (
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState State TCPState
established atomic.Bool established atomic.Bool
tombstone atomic.Bool tombstone atomic.Bool
@@ -120,7 +123,7 @@ type TCPTracker struct {
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
done chan struct{} tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@@ -131,21 +134,28 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout = DefaultTCPTimeout timeout = DefaultTCPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
timeout: timeout, timeout: timeout,
flowLogger: flowLogger, flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) (ConnKey, bool) { func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -154,9 +164,10 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
if exists { if exists {
conn.Lock() conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.UpdateLastSeen()
conn.Unlock() conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@@ -164,37 +175,36 @@ func (t *TCPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// TrackOutbound records an outbound TCP connection // TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
} }
} }
// TrackInbound processes an inbound TCP packet and updates connection state // TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress) t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction) { func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists { if exists {
return return
} }
conn := &TCPConnTrack{ conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(), FlowId: uuid.New(),
Direction: direction, Direction: direction,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
}, },
SourcePort: srcPort,
DestPort: dstPort,
} }
conn.UpdateLastSeen()
conn.established.Store(false) conn.established.Store(false)
conn.tombstone.Store(false) conn.tombstone.Store(false)
@@ -205,12 +215,17 @@ func (t *TCPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
t.connections[key] = conn t.connections[key] = conn
t.mutex.Unlock() t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, key, conn) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -231,15 +246,15 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetEstablished(false) conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key) t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
return true return true
} }
conn.Lock() conn.Lock()
t.updateState(key, conn, flags, false) t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished() isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags) isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock() conn.Unlock()
@@ -249,6 +264,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
conn.UpdateLastSeen()
state := conn.State state := conn.State
defer func() { defer func() {
if state != conn.State { if state != conn.State {
@@ -287,17 +304,24 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateCloseWait conn.State = TCPStateCloseWait
} }
conn.SetEstablished(false) conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
switch { switch {
case flags&TCPFin != 0 && flags&TCPAck != 0: case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing conn.State = TCPStateClosing
case flags&TCPFin != 0: case flags&TCPFin != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPAck != 0: case flags&TCPAck != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
@@ -305,7 +329,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key) t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
@@ -314,7 +338,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection %s closed (simultaneous)", key) t.logger.Trace("TCP connection %s closed (simultaneous)", key)
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@@ -328,7 +352,7 @@ func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, i
conn.SetTombstone() conn.SetTombstone()
// Send close event for gracefully closed connections // Send close event for gracefully closed connections
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key) t.logger.Trace("TCP connection %s closed gracefully", key)
} }
} }
@@ -375,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false return false
} }
func (t *TCPTracker) cleanupRoutine() { func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -411,11 +437,11 @@ func (t *TCPTracker) cleanup() {
// Return IPs to pool // Return IPs to pool
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up timed-out TCP connection %s", &key) t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change // event already handled by state change
if conn.State != TCPStateTimeWait { if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, key, conn) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
@@ -423,8 +449,7 @@ func (t *TCPTracker) cleanup() {
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()
@@ -446,15 +471,20 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *TCPConnTrack) { func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.TCP, Protocol: nftypes.TCP,
SourceIP: key.SrcIP, SourceIP: conn.SourceIP,
DestIP: key.DstIP, DestIP: conn.DestIP,
SourcePort: key.SrcPort, SourcePort: conn.SourcePort,
DestPort: key.DstPort, DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
}) })
} }

View File

@@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -12,8 +12,8 @@ func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc) require.Equal(t, !tt.wantDrop, isValid, tt.desc)
}) })
} }
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper() t.Helper()
// Send initial SYN // Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK // Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK // Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer // Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake") require.True(t, valid, "Data should be allowed after handshake")
}, },
}, },
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN // Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN // Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed") require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side // Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed") require.True(t, valid, "FIN should be allowed")
// Send final ACK // Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
}, },
{ {
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets // Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK // Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed") require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK // Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed") require.True(t, valid, "Final ACKs should be allowed")
}, },
}, },
@@ -165,8 +165,8 @@ func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established", name: "RST in established",
setupState: func() { setupState: func() {
// Establish connection first // Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: true, wantValid: true,
desc: "Should accept RST for established connection", desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection", name: "RST without connection",
setupState: func() {}, setupState: func() {},
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: false, wantValid: false,
desc: "Should reject RST without connection", desc: "Should reject RST without connection",
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST() tt.sendRST()
// Verify connection state is as expected // Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
@@ -220,15 +225,15 @@ func TestRSTHandling(t *testing.T) {
} }
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
} }
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
@@ -236,12 +241,12 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} }
}) })
@@ -249,17 +254,17 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
} }
}) })
@@ -267,16 +272,16 @@ func BenchmarkTCPTracker(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
if i%2 == 0 { if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else { } else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
} }
i++ i++
} }
@@ -291,10 +296,10 @@ func BenchmarkCleanup(b *testing.B) {
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
// Wait for connections to expire // Wait for connections to expire

View File

@@ -1,7 +1,8 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"sync" "sync"
"time" "time"
@@ -21,6 +22,8 @@ const (
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
type UDPConnTrack struct { type UDPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
} }
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
@@ -29,8 +32,8 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{}
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@@ -40,34 +43,41 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
flowLogger: flowLogger, flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort); !exists { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction // if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
} }
} }
// TrackInbound records an inbound UDP connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) TrackInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
} }
func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) (ConnKey, bool) { func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -75,6 +85,7 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
if exists { if exists {
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true return key, true
} }
@@ -82,21 +93,21 @@ func (t *UDPTracker) updateIfExists(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// track is the common implementation for tracking both inbound and outbound connections // track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, direction nftypes.Direction) { func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort) key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists { if exists {
return return
} }
conn := &UDPConnTrack{ conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{ BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(), FlowId: uuid.New(),
Direction: direction, Direction: direction,
SourceIP: key.SrcIP, SourceIP: srcIP,
DestIP: key.DstIP, DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
}, },
SourcePort: srcPort,
DestPort: dstPort,
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
@@ -105,12 +116,17 @@ func (t *UDPTracker) track(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort u
t.mutex.Unlock() t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key) t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, key, conn) t.sendEvent(nftypes.TypeStart, conn, ruleID)
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
@@ -121,17 +137,20 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
conn.UpdateLastSeen() conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true return true
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() { func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -145,16 +164,16 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %s (timeout)", key) t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
t.sendEvent(nftypes.TypeEnd, key, conn) key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() { func (t *UDPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
t.connections = nil t.connections = nil
@@ -162,11 +181,16 @@ func (t *UDPTracker) Close() {
} }
// GetConnection safely retrieves a connection state // GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock() t.mutex.RLock()
defer t.mutex.RUnlock() defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := t.connections[key] conn, exists := t.connections[key]
return conn, exists return conn, exists
} }
@@ -176,15 +200,20 @@ func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, key ConnKey, conn *UDPConnTrack) { func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{ t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId, FlowID: conn.FlowId,
Type: typ, Type: typ,
RuleID: ruleID,
Direction: conn.Direction, Direction: conn.Direction,
Protocol: nftypes.UDP, Protocol: nftypes.UDP,
SourceIP: key.SrcIP, SourceIP: conn.SourceIP,
DestIP: key.DstIP, DestIP: conn.DestIP,
SourcePort: key.SrcPort, SourcePort: conn.SourcePort,
DestPort: key.DstPort, DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
}) })
} }

View File

@@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip" "net/netip"
"testing" "testing"
"time" "time"
@@ -35,7 +35,7 @@ func TestNewUDPTracker(t *testing.T) {
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker) assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done) assert.NotNil(t, tracker.tickerCancel)
}) })
} }
} }
@@ -49,10 +49,15 @@ func TestUDPTracker_TrackOutbound(t *testing.T) {
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked // Verify connection was tracked
key := makeConnKey(srcIP.AsSlice(), dstIP.AsSlice(), srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key] conn, exists := tracker.connections[key]
require.True(t, exists) require.True(t, exists)
assert.True(t, conn.SourceIP.Compare(srcIP) == 0) assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
@@ -66,18 +71,18 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger, flowLogger) tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
// Track outbound connection // Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct { tests := []struct {
name string name string
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
sleep time.Duration sleep time.Duration
@@ -94,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
}, },
{ {
name: "invalid source IP", name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP, dstIP: srcIP,
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
@@ -104,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{ {
name: "invalid destination IP", name: "invalid destination IP",
srcIP: dstIP, srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"), dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
sleep: 0, sleep: 0,
@@ -144,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 { if tt.sleep > 0 {
time.Sleep(tt.sleep) time.Sleep(tt.sleep)
} }
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
@@ -155,42 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval // Create tracker with custom cleanup interval
tracker := &UDPTracker{ tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), tickerCancel: tickerCancel,
logger: logger, logger: logger,
flowLogger: flowLogger, flowLogger: flowLogger,
} }
// Start cleanup routine // Start cleanup routine
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
// Add some connections // Add some connections
connections := []struct { connections := []struct {
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
}{ }{
{ {
srcIP: net.ParseIP("192.168.1.2"), srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"), dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345, srcPort: 12345,
dstPort: 53, dstPort: 53,
}, },
{ {
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"), dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346, srcPort: 12346,
dstPort: 53, dstPort: 53,
}, },
} }
for _, conn := range connections { for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
} }
// Verify initial connections // Verify initial connections
@@ -215,12 +223,12 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
} }
}) })
@@ -228,17 +236,17 @@ func BenchmarkUDPTracker(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
} }
}) })
} }

View File

@@ -15,13 +15,16 @@ import (
// handleICMP handles ICMP packets from the network stack // handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
flowID := uuid.New()
// Extract ICMP header to get type and code
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type()) icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code()) icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode) f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
@@ -33,8 +36,6 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err != nil { if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err) f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
} }
@@ -42,30 +43,15 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {
f.logger.Debug("Failed to close ICMP socket: %v", err) f.logger.Debug("Failed to close ICMP socket: %v", err)
} }
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}() }()
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader()) fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
// For Echo Requests, send and handle response if _, err = conn.WriteTo(payload, dst); err != nil {
switch icmpHdr.Type() {
case header.ICMPv4Echo:
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id, flowID)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err) f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true return true
} }
@@ -73,21 +59,20 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
f.logger.Trace("Forwarded ICMP packet %v type %v code %v", f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
f.handleEchoResponse(icmpHdr, conn, id)
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
}
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
return true return true
} }
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID, flowID uuid.UUID) bool { func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err) f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true return
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
@@ -96,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err) f.logger.Error("Failed to read ICMP response: %v", err)
} }
return true return
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -117,13 +102,11 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err) f.logger.Error("Failed to inject ICMP response: %v", err)
return true return
} }
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v", f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code()) epID(id), icmpHdr.Type(), icmpHdr.Code())
return true
} }
// sendICMPEvent stores flow events for ICMP packets // sendICMPEvent stores flow events for ICMP packets
@@ -134,9 +117,11 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: nftypes.ICMP, Protocol: nftypes.ICMP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType, ICMPType: icmpType,
ICMPCode: icmpCode, ICMPCode: icmpCode,
// TODO: get packets/bytes
}) })
} }

View File

@@ -22,7 +22,14 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID() id := r.ID()
flowID := uuid.New() flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id)
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
@@ -51,6 +58,7 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id)) f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep, flowID) go f.proxyTCP(id, inConn, outConn, ep, flowID)
@@ -66,7 +74,7 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
} }
ep.Close() ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}() }()
// Create context for managing the proxy goroutines // Create context for managing the proxy goroutines
@@ -98,17 +106,27 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
} }
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: 6, Protocol: nftypes.TCP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.LocalPort, SourcePort: id.RemotePort,
DestPort: id.RemotePort, DestPort: id.LocalPort,
}) }
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
} }

View File

@@ -89,21 +89,6 @@ func (f *udpForwarder) Stop() {
} }
} }
// sendUDPEvent stores flow events for UDP connections
func (f *udpForwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: 17,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()),
SourcePort: id.LocalPort,
DestPort: id.RemotePort,
})
}
// cleanup periodically removes idle UDP connections // cleanup periodically removes idle UDP connections
func (f *udpForwarder) cleanup() { func (f *udpForwarder) cleanup() {
ticker := time.NewTicker(time.Minute) ticker := time.NewTicker(time.Minute)
@@ -140,8 +125,6 @@ func (f *udpForwarder) cleanup() {
f.Unlock() f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
f.sendUDPEvent(nftypes.TypeEnd, idle.conn.flowID, idle.id)
} }
} }
} }
@@ -165,13 +148,19 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
} }
flowID := uuid.New() flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id)
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
// TODO: Send ICMP error message // TODO: Send ICMP error message
return return
} }
@@ -184,7 +173,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return return
} }
@@ -212,13 +200,14 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
f.sendUDPEvent(nftypes.TypeEnd, flowID, id)
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id)) f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
} }
@@ -238,7 +227,7 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id) f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}() }()
errChan := make(chan error, 2) errChan := make(chan error, 2)
@@ -264,19 +253,30 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
} }
} }
// sendUDPEvent stores flow events for UDP connections, mirrors the TCP version // sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID) { func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
f.flowLogger.StoreEvent(nftypes.EventFields{ fields := nftypes.EventFields{
FlowID: flowID, FlowID: flowID,
Type: typ, Type: typ,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: 17, // UDP protocol number Protocol: nftypes.UDP,
// TODO: handle ipv6 // TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.LocalAddress.As4()), SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.RemoteAddress.As4()), DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.LocalPort, SourcePort: id.RemotePort,
DestPort: id.RemotePort, DestPort: id.LocalPort,
}) }
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
} }
func (c *udpPacketConn) updateLastSeen() { func (c *udpPacketConn) updateLastSeen() {

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32) m.ipv4Bitmap[high] |= 1 << (low % 32)
} }
func (m *localIPManager) checkBitmapBit(ip net.IP) bool { func (m *localIPManager) checkBitmapBit(ip []byte) bool {
ipv4 := ip.To4() high := (uint16(ip[0]) << 8) | uint16(ip[1])
if ipv4 == nil { low := (uint16(ip[2]) << 8) | uint16(ip[3])
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
} }
@@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil return nil
} }
func (m *localIPManager) IsLocalIP(ip net.IP) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil { if ip.Is4() {
return m.checkBitmapBit(ipv4) return m.checkBitmapBit(ip.AsSlice())
} }
return false return false

View File

@@ -2,90 +2,91 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestLocalIPManager(t *testing.T) { func TestLocalIPManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupAddr iface.WGAddress setupAddr wgaddr.Address
testIP net.IP testIP netip.Addr
expected bool expected bool
}{ }{
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
}, },
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
}, },
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
}, },
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
}, },
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
}, },
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("fe80::"), IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128), Mask: net.CIDRMask(64, 128),
}, },
}, },
testIP: net.ParseIP("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
}, },
} }
@@ -95,7 +96,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager() manager := newLocalIPManager()
mock := &IFaceMock{ mock := &IFaceMock{
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return tt.setupAddr return tt.setupAddr
}, },
} }
@@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests)) t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) { t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip)) result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip) require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
}) })
} }

View File

@@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"net"
"net/netip" "net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -13,7 +12,7 @@ import (
type PeerRule struct { type PeerRule struct {
id string id string
mgmtId []byte mgmtId []byte
ip net.IP ip netip.Addr
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType

View File

@@ -2,7 +2,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net/netip"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
} }
type PacketTrace struct { type PacketTrace struct {
SourceIP net.IP SourceIP netip.Addr
DestinationIP net.IP DestinationIP netip.Addr
Protocol string Protocol string
SourcePort uint16 SourcePort uint16
DestinationPort uint16 DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
} }
type PacketBuilder struct { type PacketBuilder struct {
SrcIP net.IP SrcIP netip.Addr
DstIP net.IP DstIP netip.Addr
Protocol fw.Protocol Protocol fw.Protocol
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP, SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP, DstIP: p.DstIP.AsSlice(),
} }
} }
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP) return m.traceInbound(packetData, trace, d, srcIP, dstIP)
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return trace if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
} }
if !m.handleRouting(trace) { if !m.handleRouting(trace) {
return trace return trace
} }
if m.nativeRouter { if m.nativeRouter.Load() {
return m.handleNativeRouter(trace) return m.handleNativeRouter(trace)
} }
return m.handleRouteACLs(trace, d, srcIP, dstIP) return m.handleRouteACLs(trace, d, srcIP, dstIP)
} }
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP) allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found" msg := "No existing connection found"
if allowed { if allowed {
msg = m.buildConntrackStateMessage(d) msg = m.buildConntrackStateMessage(d)
@@ -309,39 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg return msg
} }
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.localForwarding {
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
return true
}
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StageRouting, "Packet destined for local delivery", true)
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "implicit" strRuleId := "<no id>"
if ruleId != nil { if ruleId != nil {
strRuleId = string(ruleId) strRuleId = string(ruleId)
} }
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId) msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked { if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId) msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true
} }
trace.AddResult(StagePeerACL, msg, !blocked) trace.AddResult(StagePeerACL, msg, true)
// Handle netstack mode
if m.netstack { if m.netstack {
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
} }
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) // In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return true return true
} }
func (m *Manager) handleRouting(trace *PacketTrace) bool { func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled { if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false) trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false return false
@@ -357,14 +366,14 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace return trace
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d) proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
strId := string(id) strId := string(id)
if id == nil { if id == nil {
strId = "implicit" strId = "<no id>"
} }
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId) msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
@@ -373,7 +382,7 @@ func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP n
} }
trace.AddResult(StageRouteACL, msg, allowed) trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil { if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
} }
@@ -392,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state // will create or update the connection state
dropped := m.processOutgoingHooks(packetData) dropped := m.processOutgoingHooks(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else { } else {

View File

@@ -0,0 +1,440 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -66,9 +67,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
// outgoingRules is used for hooks only // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet outgoingRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@@ -80,9 +81,9 @@ type Manager struct {
// indicates whether server routes are disabled // indicates whether server routes are disabled
disableServerRoutes bool disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves // indicates whether we forward packets not destined for ourselves
routingEnabled bool routingEnabled atomic.Bool
// indicates whether we leave forwarding and filtering to the native firewall // indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool nativeRouter atomic.Bool
// indicates whether we track outbound connections // indicates whether we track outbound connections
stateful bool stateful bool
// indicates whether wireguards runs in netstack mode // indicates whether wireguards runs in netstack mode
@@ -95,7 +96,7 @@ type Manager struct {
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger flowLogger nftypes.FlowLogger
} }
@@ -168,18 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}, },
}, },
nativeFirewall: nativeFirewall, nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface, wgIface: iface,
localipmanager: newLocalIPManager(), localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes, disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack, stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()), logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger, flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
} }
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@@ -211,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil { if m.forwarder.Load() == nil {
return nil return nil
} }
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
@@ -255,20 +256,20 @@ func (m *Manager) determineRouting() error {
switch { switch {
case disableUspRouting: case disableUspRouting:
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is disabled") log.Info("userspace routing is disabled")
case m.disableServerRoutes: case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack // if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("server routes are disabled") log.Info("server routes are disabled")
case forceUserspaceRouter: case forceUserspaceRouter:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
@@ -276,19 +277,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("native routing is enabled") log.Info("native routing is enabled")
default: default:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing enabled by default") log.Info("userspace routing enabled by default")
} }
if m.routingEnabled && !m.nativeRouter { if m.routingEnabled.Load() && !m.nativeRouter.Load() {
return m.initForwarder() return m.initForwarder()
} }
@@ -297,24 +298,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors // initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error { func (m *Manager) initForwarder() error {
if m.forwarder != nil { if m.forwarder.Load() != nil {
return nil return nil
} }
// Only supported in userspace mode as we need to inject packets back into wireguard directly // Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice() intf := m.wgIface.GetWGDevice()
if intf == nil { if intf == nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil { if err != nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
} }
m.forwarder = forwarder m.forwarder.Store(forwarder)
log.Debug("forwarder initialized") log.Debug("forwarder initialized")
@@ -330,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
} }
@@ -341,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveNatRule(pair)
} }
return nil return nil
@@ -360,17 +361,23 @@ func (m *Manager) AddPeerFiltering(
action firewall.Action, action firewall.Action,
_ string, _ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
mgmtId: id, mgmtId: id,
ip: ip, ip: i,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
} }
if ipNormalized := ip.To4(); ipNormalized != nil { if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
} }
if s := r.ip.String(); s == "0.0.0.0" || s == "::" { if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@@ -395,10 +402,10 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet) m.incomingRules[r.ip] = make(RuleSet)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
@@ -412,13 +419,10 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action) return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
@@ -432,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
action: action, action: action,
} }
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule) m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
@@ -468,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok { if _, ok := m.incomingRules[r.ip][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id) delete(m.incomingRules[r.ip], r.id)
return nil return nil
} }
@@ -504,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
} }
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData) return m.processOutgoingHooks(packetData, size)
} }
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData) return m.dropFilter(packetData, size)
} }
// UpdateLocalIPs updates the list of local IPs // UpdateLocalIPs updates the list of local IPs
@@ -518,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface) return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -534,31 +537,34 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return false return false
} }
// Track all protocols if stateful mode is enabled if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
if m.stateful { return true
m.trackOutbound(d, srcIP, dstIP)
} }
// Process UDP hooks even if stateful mode is disabled if m.stateful {
if d.decoded[1] == layers.LayerTypeUDP { m.trackOutbound(d, srcIP, dstIP, size)
return m.checkUDPHooks(d, dstIP, packetData)
} }
return false return false
} }
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
switch d.decoded[0] { switch d.decoded[0] {
case layers.LayerTypeIPv4: case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
case layers.LayerTypeIPv6: case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
default: default:
return nil, nil return netip.Addr{}, netip.Addr{}
} }
} }
@@ -585,51 +591,70 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP net.IP) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
} }
} }
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP net.IP) { func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort)) m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
flags := getTCPFlags(&d.tcp) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags) m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.Seq, d.icmp4.TypeCode) m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
} }
} }
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { // udpHooksDrop checks if any UDP hooks should drop the packet
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
if rules, exists := m.outgoingRules[ipKey]; exists { m.mutex.RLock()
for _, rule := range rules { defer m.mutex.RUnlock()
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.udpHook(packetData) // Check specific destination IP first
} if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
} }
} }
} }
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false return false
} }
// dropFilter implements filtering logic for incoming packets. // dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool { func (m *Manager) dropFilter(packetData []byte, size int) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -638,19 +663,19 @@ func (m *Manager) dropFilter(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked. // This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false return false
} }
if m.localipmanager.IsLocalIP(dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData) return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
} }
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
@@ -658,27 +683,28 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
if ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d); blocked { ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
srcAddr, _ := netip.AddrFromSlice(srcIP) if blocked {
dstAddr, _ := netip.AddrFromSlice(dstIP)
_, pnum := getProtocolFromPacket(d) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleId, pnum, srcAddr, srcPort, dstAddr, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: ruleId, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: pnum,
SourceIP: srcAddr, SourceIP: srcIP,
DestIP: dstAddr, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
}) })
return true return true
} }
@@ -689,7 +715,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
} }
// track inbound packets to get the correct direction and session id for flows // track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP) m.trackInbound(d, srcIP, dstIP, ruleID, size)
return false return false
} }
@@ -700,12 +726,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
return false return false
} }
if m.forwarder == nil { if m.forwarder.Load() == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true return true
} }
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err) m.logger.Error("Failed to inject local packet: %v", err)
} }
@@ -715,37 +741,34 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic. // handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled { if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP) srcIP, dstIP)
return true return true
} }
// Pass to native stack if native router is enabled or forced // Pass to native stack if native router is enabled or forced
if m.nativeRouter { if m.nativeRouter.Load() {
return false return false
} }
proto, pnum := getProtocolFromPacket(d) proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
if id, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass { if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
id, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: id, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: pnum,
SourceIP: srcAddr, SourceIP: srcIP,
DestIP: dstAddr, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
DestPort: dstPort, DestPort: dstPort,
// TODO: icmp type/code // TODO: icmp type/code
@@ -754,7 +777,7 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err) m.logger.Error("Failed to inject incoming packet: %v", err)
} }
@@ -799,7 +822,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true return true
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound( return m.tcpTracker.IsValidInbound(
@@ -808,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
uint16(d.tcp.SrcPort), uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort), uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp), getTCPFlags(&d.tcp),
size,
) )
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@@ -816,6 +840,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
dstIP, dstIP,
uint16(d.udp.SrcPort), uint16(d.udp.SrcPort),
uint16(d.udp.DstPort), uint16(d.udp.DstPort),
size,
) )
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@@ -823,8 +848,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
srcIP, srcIP,
dstIP, dstIP,
d.icmp4.Id, d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(), d.icmp4.TypeCode.Type(),
size,
) )
// TODO: ICMPv6 // TODO: ICMPv6
@@ -844,20 +869,22 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) ([]byte, bool) { func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return nil, false return nil, false
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
return mgmtId, filter return mgmtId, filter
} }
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return mgmtId, filter return mgmtId, filter
} }
@@ -882,10 +909,10 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false return false
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue continue
} }
@@ -919,16 +946,13 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
return nil, false, false return nil, false, false
} }
// routeACLsPass returns treu if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
@@ -972,9 +996,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
@@ -984,23 +1006,22 @@ func (m *Manager) AddUDPPacketHook(
udpHook: hook, udpHook: hook,
} }
if ip.To4() != nil { if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
} }
m.mutex.Lock() m.mutex.Lock()
if in { if in {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule) m.incomingRules[r.ip] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) m.outgoingRules[r.ip] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip][r.id] = r
} }
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@@ -1048,20 +1069,21 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.forwarder == nil { fwder := m.forwarder.Load()
if fwder == nil {
return nil return nil
} }
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
// don't stop forwarder if in use by netstack // don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
return nil return nil
} }
m.forwarder.Stop() fwder.Stop()
m.forwarder = nil m.forwarder.Store(nil)
log.Debug("forwarder stopped") log.Debug("forwarder stopped")

View File

@@ -171,7 +171,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -193,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -216,7 +216,7 @@ func BenchmarkStateScaling(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -230,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Test packet // Test packet
@@ -238,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut) manager.processOutgoingHooks(testOut, 0)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn) manager.dropFilter(testIn, 0)
} }
}) })
} }
@@ -264,7 +264,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -278,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -463,7 +463,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
// Setup scenario // Setup scenario
@@ -477,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -590,7 +590,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -624,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@@ -655,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
} }
@@ -678,7 +678,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -761,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
} }
@@ -794,7 +794,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -826,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Pre-generate test packets // Pre-generate test packets
@@ -856,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
}) })
@@ -879,7 +879,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Reset(nil)) require.NoError(b, manager.Close(nil))
}) })
manager.SetNetwork(&net.IPNet{ manager.SetNetwork(&net.IPNet{
@@ -950,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
}) })
@@ -1054,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tc := range cases { for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
} }
} }

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
@@ -26,8 +26,8 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
@@ -39,7 +39,7 @@ func TestPeerACLFiltering(t *testing.T) {
require.NotNil(t, manager) require.NotNil(t, manager)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}) })
manager.wgNetwork = wgNet manager.wgNetwork = wgNet
@@ -192,7 +192,7 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist") require.True(t, isDropped, "Packet should be dropped when no rules exist")
}) })
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
}) })
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped) require.Equal(t, tc.shouldBeBlocked, isDropped)
}) })
} }
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
@@ -306,11 +306,11 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter) require.False(tb, manager.nativeRouter.Load())
tb.Cleanup(func() { tb.Cleanup(func() {
require.NoError(tb, manager.Reset(nil)) require.NoError(tb, manager.Close(nil))
}) })
return manager return manager
@@ -818,8 +818,8 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule)) require.NoError(t, manager.DeleteRouteRule(rule))
}) })
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
@@ -1006,8 +1006,8 @@ func TestRouteACLOrder(t *testing.T) {
}) })
for i, p := range tc.packets { for i, p := range tc.packets {
srcIP := net.ParseIP(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := net.ParseIP(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)

View File

@@ -18,17 +18,17 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
@@ -54,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress { func (i *IFaceMock) Address() wgaddr.Address {
if i.AddressFunc == nil { if i.AddressFunc == nil {
return iface.WGAddress{} return wgaddr.Address{}
} }
return i.AddressFunc() return i.AddressFunc()
} }
@@ -125,19 +125,19 @@ func TestManagerDeleteRule(t *testing.T) {
return return
} }
ip := net.ParseIP("192.168.1.1") ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
rule2, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "") rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok { if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -151,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok { if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -162,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string name string
in bool in bool
expDir fw.RuleDirection expDir fw.RuleDirection
ip net.IP ip netip.Addr
dPort uint16 dPort uint16
hook func([]byte) bool hook func([]byte) bool
expectedID string expectedID string
@@ -171,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook", name: "Test Outgoing UDP Packet Hook",
in: false, in: false,
expDir: fw.RuleDirectionOUT, expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1), ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000, dPort: 8000,
hook: func([]byte) bool { return true }, hook: func([]byte) bool { return true },
}, },
@@ -179,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook", name: "Test Incoming UDP Packet Hook",
in: true, in: true,
expDir: fw.RuleDirectionIN, expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback, ip: netip.MustParseAddr("::1"),
dPort: 9000, dPort: 9000,
hook: func([]byte) bool { return false }, hook: func([]byte) bool { return false },
}, },
@@ -196,11 +196,11 @@ func TestAddUDPPacketHook(t *testing.T) {
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
for _, rule := range manager.incomingRules[tt.ip.String()] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
@@ -208,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip.String()] { for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} }
if !tt.ip.Equal(addedRule.ip) { if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
@@ -255,7 +255,7 @@ func TestManagerReset(t *testing.T) {
return return
} }
err = m.Reset(nil) err = m.Close(nil)
if err != nil { if err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
@@ -269,8 +269,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"), IP: net.ParseIP("100.10.0.0"),
@@ -328,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes()) { if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
if err = m.Reset(nil); err != nil { if err = m.Close(nil); err != nil {
t.Errorf("failed to reset Manager: %v", err) t.Errorf("failed to reset Manager: %v", err)
return return
} }
@@ -352,12 +352,12 @@ func TestRemovePacketHook(t *testing.T) {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
@@ -403,7 +403,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false hookCalled := false
hookID := manager.AddUDPPacketHook( hookID := manager.AddUDPPacketHook(
false, false,
net.ParseIP("100.10.0.100"), netip.MustParseAddr("100.10.0.100"),
53, 53,
func([]byte) bool { func([]byte) bool {
hookCalled = true hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes()) result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes()) result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result) require.False(t, result)
} }
@@ -484,7 +484,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
time.Sleep(time.Second) time.Sleep(time.Second)
defer func() { defer func() {
if err := manager.Reset(nil); err != nil { if err := manager.Close(nil); err != nil {
t.Errorf("clear the manager state: %v", err) t.Errorf("clear the manager state: %v", err)
} }
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -530,7 +530,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}, },
} }
defer func() { defer func() {
require.NoError(t, manager.Reset(nil)) require.NoError(t, manager.Close(nil))
}() }()
// Set up packet parameters // Set up packet parameters
@@ -569,11 +569,11 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes()) drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match") require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
@@ -636,12 +636,12 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes()) drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
if cp.shouldAllow { if cp.shouldAllow {
conn, exists := manager.udpTracker.GetConnection(srcIP.AsSlice(), srcPort, dstIP.AsSlice(), dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should still exist during valid window") require.True(t, exists, "Connection should still exist during valid window")
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
"LastSeen should be updated for valid responses") "LastSeen should be updated for valid responses")
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes()) drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes()) drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@@ -5,7 +5,6 @@ import (
"net" "net"
"net/netip" "net/netip"
"runtime" "runtime"
"strings"
"sync" "sync"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
@@ -14,6 +13,8 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -52,9 +53,10 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
address: address,
} }
rc := receiverCreator{ rc := receiverCreator{
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
return s.udpMux, nil return s.udpMux, nil
} }
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) { func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
fakeUDPAddr, err := fakeAddress(peerAddress)
if err != nil {
return nil, err
}
// force IPv4
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
b.endpoints[fakeAddr] = conn b.endpoints[fakeIP] = conn
b.endpointsMu.Unlock() b.endpointsMu.Unlock()
return fakeUDPAddr, nil
} }
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) { func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
if !ok {
log.Warnf("failed to convert IP to netip.Addr")
return
}
b.endpointsMu.Lock() b.endpointsMu.Lock()
defer b.endpointsMu.Unlock() defer b.endpointsMu.Unlock()
delete(b.endpoints, fakeAddr)
delete(b.endpoints, fakeIP)
} }
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
} }
} }
// fakeAddress returns a fake address that is used to as an identifier for the peer.
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
newAddr := &net.UDPAddr{
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
Port: peerAddress.Port,
}
return newAddr, nil
}
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message { func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
return msgsPool.Get().(*[]ipv6.Message) return msgsPool.Get().(*[]ipv6.Message)
} }

View File

@@ -17,6 +17,8 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// FilterFn is a function that filters out candidates based on the address. // FilterFn is a function that filters out candidates based on the address.
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
filterFn: params.FilterFn, filterFn: params.FilterFn,
address: params.WGAddress,
} }
// embed UDPMux // embed UDPMux
@@ -118,6 +122,7 @@ type udpConn struct {
filterFn FilterFn filterFn FilterFn
// TODO: reset cache on route changes // TODO: reset cache on route changes
addrCache sync.Map addrCache sync.Map
address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil { if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err) log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else { } else {

View File

@@ -9,13 +9,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct { type WGTunDevice struct {
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -31,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *WGTunDevice) WgAddress() WGAddress { func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -29,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -10,16 +11,16 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host // DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not. // Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument. // Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) { if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -30,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) UpdateAddr(addr WGAddress) error { func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }

View File

@@ -14,12 +14,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
name string name string
address WGAddress address wgaddr.Address
wgPort int wgPort int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock, UDPConn: rawSock,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil return t.udpMux, nil
} }
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
return closErr return closErr
} }
func (t *TunKernelDevice) WgAddress() WGAddress { func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,12 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
return nil return nil
} }
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
return nil return nil
} }
func (t *TunNetstackDevice) WgAddress() WGAddress { func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -12,11 +12,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type USPDevice struct { type USPDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -28,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &USPDevice{
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *USPDevice) UpdateAddr(address WGAddress) error { func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
return nil return nil
} }
func (t *USPDevice) WgAddress() WGAddress { func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,13 +13,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -32,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
} }
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd" "github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name) link, err := freebsd.LinkByName(l.name)
if err != nil { if err != nil {
return fmt.Errorf("link by name: %w", err) return fmt.Errorf("link by name: %w", err)

View File

@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses //delete existing addresses
list, err := netlink.AddrList(l, 0) list, err := netlink.AddrList(l, 0)
if err != nil { if err != nil {

View File

@@ -7,13 +7,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -28,8 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
type WGAddress = device.WGAddress
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
} }
// Address returns the interface address // Address returns the interface address
func (w *WGIface) Address() device.WGAddress { func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr) addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,17 +3,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,

View File

@@ -6,17 +6,18 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -5,17 +5,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,12 +8,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{} wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil return wgIFace, nil
} }
if device.ModuleTunIsLoaded() { if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,16 +4,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -6,6 +6,7 @@ package mocks
import ( import (
net "net" net "net"
"net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
} }
// AddUDPPacketHook mocks base method. // AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
} }
// DropIncoming mocks base method. // DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0) ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
} }
// DropOutgoing mocks base method. // DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0) ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
} }
// RemovePacketHook mocks base method. // RemovePacketHook mocks base method.

View File

@@ -55,7 +55,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy)) skipProxy, err := strconv.ParseBool(os.Getenv(EnvSkipProxy))
if err != nil { if err != nil {
log.Errorf("failed to parse NB_ETSTACK_SKIP_PROXY: %s", err) log.Errorf("failed to parse %s: %s", EnvSkipProxy, err)
} }
if skipProxy { if skipProxy {
return nsTunDev, tunNet, nil return nsTunDev, tunNet, nil

View File

@@ -1,29 +1,29 @@
package device package wgaddr
import ( import (
"fmt" "fmt"
"net" "net"
) )
// WGAddress WireGuard parsed address // Address WireGuard parsed address
type WGAddress struct { type Address struct {
IP net.IP IP net.IP
Network *net.IPNet Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return WGAddress{}, err return Address{}, err
} }
return WGAddress{ return Address{
IP: ip, IP: ip,
Network: network, Network: network,
}, nil }, nil
} }
func (addr WGAddress) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,13 +17,13 @@ import (
type ProxyBind struct { type ProxyBind struct {
Bind *bind.ICEBind Bind *bind.ICEBind
wgAddr *net.UDPAddr fakeNetIP *netip.AddrPort
wgEndpoint *bind.Endpoint wgBindEndpoint *bind.Endpoint
remoteConn net.Conn remoteConn net.Conn
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
closeMu sync.Mutex closeMu sync.Mutex
closed bool closed bool
pausedMu sync.Mutex pausedMu sync.Mutex
paused bool paused bool
@@ -33,20 +34,24 @@ type ProxyBind struct {
// endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the
// WireGuard configuration. // WireGuard configuration.
func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error {
addr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) fakeNetIP, err := fakeAddress(nbAddr)
if err != nil { if err != nil {
return err return err
} }
p.wgAddr = addr p.fakeNetIP = fakeNetIP
p.wgEndpoint = addrToEndpoint(addr) p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP}
p.remoteConn = remoteConn p.remoteConn = remoteConn
p.ctx, p.cancel = context.WithCancel(ctx) p.ctx, p.cancel = context.WithCancel(ctx)
return err return nil
} }
func (p *ProxyBind) EndpointAddr() *net.UDPAddr { func (p *ProxyBind) EndpointAddr() *net.UDPAddr {
return p.wgAddr return &net.UDPAddr{
IP: p.fakeNetIP.Addr().AsSlice(),
Port: int(p.fakeNetIP.Port()),
Zone: p.fakeNetIP.Addr().Zone(),
}
} }
func (p *ProxyBind) Work() { func (p *ProxyBind) Work() {
@@ -54,6 +59,8 @@ func (p *ProxyBind) Work() {
return return
} }
p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn)
p.pausedMu.Lock() p.pausedMu.Lock()
p.paused = false p.paused = false
p.pausedMu.Unlock() p.pausedMu.Unlock()
@@ -93,7 +100,7 @@ func (p *ProxyBind) close() error {
p.cancel() p.cancel()
p.Bind.RemoveEndpoint(p.wgAddr) p.Bind.RemoveEndpoint(p.fakeNetIP.Addr())
if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) {
return rErr return rErr
@@ -126,7 +133,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
msg := bind.RecvMessage{ msg := bind.RecvMessage{
Endpoint: p.wgEndpoint, Endpoint: p.wgBindEndpoint,
Buffer: buf[:n], Buffer: buf[:n],
} }
p.Bind.RecvChan <- msg p.Bind.RecvChan <- msg
@@ -134,8 +141,19 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
} }
} }
func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { // fakeAddress returns a fake address that is used to as an identifier for the peer.
ip, _ := netip.AddrFromSlice(addr.IP.To4()) // The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) {
return &bind.Endpoint{AddrPort: addrPort} octets := strings.Split(peerAddress.IP.String(), ".")
if len(octets) != 4 {
return nil, fmt.Errorf("invalid IP format")
}
fakeIP, err := netip.ParseAddr(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3]))
if err != nil {
return nil, fmt.Errorf("failed to parse new IP: %w", err)
}
netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port))
return &netipAddr, nil
} }

View File

@@ -6,8 +6,8 @@
!define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network" !define DESCRIPTION "A WireGuard®-based mesh network that connects your devices into a single private network"
!define INSTALLER_NAME "netbird-installer.exe" !define INSTALLER_NAME "netbird-installer.exe"
!define MAIN_APP_EXE "Netbird" !define MAIN_APP_EXE "Netbird"
!define ICON "ui\\netbird.ico" !define ICON "ui\\assets\\netbird.ico"
!define BANNER "ui\\banner.bmp" !define BANNER "ui\\build\\banner.bmp"
!define LICENSE_DATA "..\\LICENSE" !define LICENSE_DATA "..\\LICENSE"
!define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}" !define INSTALL_DIR "$PROGRAMFILES64\${APP_NAME}"
@@ -22,6 +22,8 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}" !define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}" !define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -68,6 +70,9 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
@@ -80,8 +85,36 @@ ShowInstDetails Show
!insertmacro MUI_LANGUAGE "English" !insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
###################################################################### ######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR" EnVar::AddValueEx "path" "$INSTDIR"
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
Sleep 3000 Sleep 3000

View File

@@ -9,13 +9,13 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
@@ -49,7 +49,7 @@ func TestDefaultManager(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
@@ -62,7 +62,7 @@ func TestDefaultManager(t *testing.T) {
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Close(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)
@@ -343,7 +343,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
@@ -356,7 +356,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
return return
} }
defer func(fw manager.Manager) { defer func(fw manager.Manager) {
_ = fw.Reset(nil) _ = fw.Close(nil)
}(fw) }(fw)
acl := NewDefaultManager(fw) acl := NewDefaultManager(fw)

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// MockIFaceMapper is a mock of IFaceMapper interface. // MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
} }
// Address mocks base method. // Address mocks base method.
func (m *MockIFaceMapper) Address() iface.WGAddress { func (m *MockIFaceMapper) Address() wgaddr.Address {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address") ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(iface.WGAddress) ret0, _ := ret[0].(wgaddr.Address)
return ret0 return ret0
} }

View File

@@ -61,7 +61,7 @@ func NewConnectClient(
} }
// Run with main logic. // Run with main logic.
func (c *ConnectClient) Run(runningChan chan error) error { func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan) return c.run(MobileDependency{}, runningChan)
} }
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil) return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
defer c.statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { if c.ctx.Err() != nil {
return nil return nil
} }
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil && runningChanOpen { if runningChan != nil {
runningChan <- nil select {
close(runningChan) case runningChan <- struct{}{}:
runningChanOpen = false default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
return nil return nil
} }
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// SetNetworkMapPersistence enables or disables network map persistence. // SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved // When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored // through the Engine's getLatestNetworkMap method. When disabled, any stored

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -30,7 +31,7 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}).GetLogger() var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter device.PacketFilter
@@ -40,9 +41,9 @@ func (w *mocWGIface) Name() string {
panic("implement me") panic("implement me")
} }
func (w *mocWGIface) Address() iface.WGAddress { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24") ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{ return wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
} }
@@ -458,7 +459,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet) packetfilter.EXPECT().SetNetwork(ipNet)
@@ -1018,7 +1019,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
mh.AssertExpectations(t) mh.AssertExpectations(t)
} }
// Reset mocks // Close mocks
if mh, ok := tc.expectedHandler.(*MockHandler); ok { if mh, ok := tc.expectedHandler.(*MockHandler); ok {
mh.ExpectedCalls = nil mh.ExpectedCalls = nil
mh.Calls = nil mh.Calls = nil

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net" "net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true return true
} }
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil ip, err := netip.ParseAddr(s.runtimeIP)
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
} }

View File

@@ -5,15 +5,15 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@@ -1,15 +1,15 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
// start flow manager right after interface creation // start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey() publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:]) e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
@@ -1428,7 +1428,7 @@ func (e *Engine) close() {
} }
if e.firewall != nil { if e.firewall != nil {
err := e.firewall.Reset(e.stateManager) err := e.firewall.Close(e.stateManager)
if err != nil { if err != nil {
log.Warnf("failed to reset firewall: %s", err) log.Warnf("failed to reset firewall: %s", err)
} }
@@ -1641,16 +1641,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns) return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
} }
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
log.Info("restarting engine") e.syncMsgMux.Lock()
CtxGetState(e.ctx).Set(StatusConnecting) defer e.syncMsgMux.Unlock()
if err := e.Stop(); err != nil { if e.ctx.Err() != nil {
log.Errorf("Failed to stop engine: %v", err) return
} }
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated") log.Infof("cancelling client context, engine will be recreated")
e.clientCancel() e.clientCancel()
} }
@@ -1662,34 +1665,17 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New() e.networkMonitor = networkmonitor.New()
go func() { go func() {
var mu sync.Mutex if err := e.networkMonitor.Listen(e.ctx); err != nil {
var debounceTimer *time.Timer if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
// Start the network monitor with a callback, Start will block until the monitor is stopped, return
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
} }
log.Errorf("network monitor error: %v", err)
// Set a new timer to debounce rapid network changes return
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
} }
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
}() }()
} }

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -75,7 +76,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() device.WGAddress AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
@@ -114,7 +115,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc() return m.NameFunc()
} }
func (m *MockWGIface) Address() device.WGAddress { func (m *MockWGIface) Address() wgaddr.Address {
return m.AddressFunc() return m.AddressFunc()
} }
@@ -364,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error

View File

@@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
) )
type rcvChan chan *types.EventFields type rcvChan chan *types.EventFields
@@ -21,15 +22,17 @@ type Logger struct {
enabled atomic.Bool enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan] rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc cancelReceiver context.CancelFunc
statusRecorder *peer.Status
Store types.Store Store types.Store
} }
func New(ctx context.Context) *Logger { func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
return &Logger{ return &Logger{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
Store: store.NewMemoryStore(), statusRecorder: statusRecorder,
Store: store.NewMemoryStore(),
} }
} }
@@ -58,13 +61,14 @@ func (l *Logger) startReceiver() {
if l.enabled.Load() { if l.enabled.Load() {
return return
} }
l.mux.Lock() l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx) ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel l.cancelReceiver = cancel
l.mux.Unlock() l.mux.Unlock()
c := make(rcvChan, 100) c := make(rcvChan, 100)
l.rcvChan.Swap(&c) l.rcvChan.Store(&c)
l.enabled.Store(true) l.enabled.Store(true)
for { for {
@@ -73,12 +77,15 @@ func (l *Logger) startReceiver() {
log.Info("flow Memory store receiver stopped") log.Info("flow Memory store receiver stopped")
return return
case eventFields := <-c: case eventFields := <-c:
id := uuid.NewString() id := uuid.New()
event := types.Event{ event := types.Event{
ID: id, ID: id,
EventFields: *eventFields, EventFields: *eventFields,
Timestamp: time.Now(), Timestamp: time.Now(),
} }
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
event.SourceResourceID = []byte(srcResId)
event.DestResourceID = []byte(dstResId)
l.Store.StoreEvent(&event) l.Store.StoreEvent(&event)
} }
} }
@@ -100,6 +107,7 @@ func (l *Logger) stop() {
l.cancelReceiver() l.cancelReceiver()
l.cancelReceiver = nil l.cancelReceiver = nil
} }
l.rcvChan.Store(nil)
l.mux.Unlock() l.mux.Unlock()
} }
@@ -107,7 +115,7 @@ func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents() return l.Store.GetEvents()
} }
func (l *Logger) DeleteEvents(ids []string) { func (l *Logger) DeleteEvents(ids []uuid.UUID) {
l.Store.DeleteEvents(ids) l.Store.DeleteEvents(ids)
} }

View File

@@ -12,7 +12,7 @@ import (
) )
func TestStore(t *testing.T) { func TestStore(t *testing.T) {
logger := logger.New(context.Background()) logger := logger.New(context.Background(), nil)
logger.Enable() logger.Enable()
event := types.EventFields{ event := types.EventFields{

View File

@@ -2,17 +2,20 @@ package netflow
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack" "github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger" "github.com/netbirdio/netbird/client/internal/netflow/logger"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client" "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/flow/proto"
) )
@@ -29,8 +32,8 @@ type Manager struct {
} }
// NewManager creates a new netflow manager // NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte) *Manager { func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
flowLogger := logger.New(ctx) flowLogger := logger.New(ctx, statusRecorder)
var ct nftypes.ConnTracker var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() { if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
@@ -45,46 +48,80 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
} }
} }
// Update applies new flow configuration settings
// needsNewClient checks if a new client needs to be created
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
current := m.flowConfig
return previous == nil ||
!previous.Enabled ||
previous.TokenPayload != current.TokenPayload ||
previous.TokenSignature != current.TokenSignature ||
previous.URL != current.URL
}
// enableFlow starts components for flow tracking
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %s", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
}
m.logger.Enable()
if m.conntrack != nil {
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
return fmt.Errorf("start conntrack: %w", err)
}
}
return nil
}
// disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error {
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
if m.receiverClient != nil {
return m.receiverClient.Close()
}
return nil
}
// Update applies new flow configuration settings // Update applies new flow configuration settings
func (m *Manager) Update(update *nftypes.FlowConfig) error { func (m *Manager) Update(update *nftypes.FlowConfig) error {
if update == nil { if update == nil {
return nil return nil
} }
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
previous := m.flowConfig previous := m.flowConfig
m.flowConfig = update m.flowConfig = update
if update.Enabled { if update.Enabled {
if m.conntrack != nil { return m.enableFlow(previous)
if err := m.conntrack.Start(update.Counters); err != nil {
return fmt.Errorf("start conntrack: %w", err)
}
}
m.logger.Enable()
if previous == nil || !previous.Enabled {
flowClient, err := client.NewClient(m.ctx, m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature)
if err != nil {
return err
}
log.Infof("flow client connected to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs()
go m.startSender()
}
return nil
} }
if m.conntrack != nil { return m.disableFlow()
m.conntrack.Stop()
}
m.logger.Disable()
if previous != nil && previous.Enabled {
return m.receiverClient.Close()
}
return nil
} }
// Close cleans up all resources // Close cleans up all resources
@@ -95,6 +132,13 @@ func (m *Manager) Close() {
if m.conntrack != nil { if m.conntrack != nil {
m.conntrack.Close() m.conntrack.Close()
} }
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %s", err)
}
}
m.logger.Close() m.logger.Close()
} }
@@ -106,6 +150,7 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
func (m *Manager) startSender() { func (m *Manager) startSender() {
ticker := time.NewTicker(m.flowConfig.Interval) ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-m.ctx.Done(): case <-m.ctx.Done():
@@ -113,56 +158,62 @@ func (m *Manager) startSender() {
case <-ticker.C: case <-ticker.C:
events := m.logger.GetEvents() events := m.logger.GetEvents()
for _, event := range events { for _, event := range events {
log.Infof("send flow event to server: %s", event.ID) if err := m.send(event); err != nil {
err := m.send(event) log.Errorf("failed to send flow event to server: %s", err)
if err != nil { continue
log.Errorf("send flow event to server: %s", err)
} }
log.Tracef("sent flow event: %s", event.ID)
} }
} }
} }
} }
func (m *Manager) receiveACKs() { func (m *Manager) receiveACKs(client *client.GRPCClient) {
if m.receiverClient == nil { err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
return log.Tracef("received flow event ack: %s", ack.EventId)
} m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
err := m.receiverClient.Receive(m.ctx, func(ack *proto.FlowEventAck) error {
log.Infof("receive flow event ack: %s", ack.EventId)
m.logger.DeleteEvents([]string{ack.EventId})
return nil return nil
}) })
if err != nil {
log.Errorf("receive flow event ack: %s", err) if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("failed to receive flow event ack: %s", err)
} }
} }
func (m *Manager) send(event *nftypes.Event) error { func (m *Manager) send(event *nftypes.Event) error {
if m.receiverClient == nil { m.mux.Lock()
client := m.receiverClient
m.mux.Unlock()
if client == nil {
return nil return nil
} }
return m.receiverClient.Send(m.ctx, toProtoEvent(m.publicKey, event))
return client.Send(toProtoEvent(m.publicKey, event))
} }
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent { func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
protoEvent := &proto.FlowEvent{ protoEvent := &proto.FlowEvent{
EventId: event.ID, EventId: event.ID[:],
Timestamp: timestamppb.New(event.Timestamp), Timestamp: timestamppb.New(event.Timestamp),
PublicKey: publicKey, PublicKey: publicKey,
FlowFields: &proto.FlowFields{ FlowFields: &proto.FlowFields{
FlowId: event.FlowID[:], FlowId: event.FlowID[:],
RuleId: event.RuleID, RuleId: event.RuleID,
Type: proto.Type(event.Type), Type: proto.Type(event.Type),
Direction: proto.Direction(event.Direction), Direction: proto.Direction(event.Direction),
Protocol: uint32(event.Protocol), Protocol: uint32(event.Protocol),
SourceIp: event.SourceIP.AsSlice(), SourceIp: event.SourceIP.AsSlice(),
DestIp: event.DestIP.AsSlice(), DestIp: event.DestIP.AsSlice(),
RxPackets: event.RxPackets, RxPackets: event.RxPackets,
TxPackets: event.TxPackets, TxPackets: event.TxPackets,
RxBytes: event.RxBytes, RxBytes: event.RxBytes,
TxBytes: event.TxBytes, TxBytes: event.TxBytes,
SourceResourceId: event.SourceResourceID,
DestResourceId: event.DestResourceID,
}, },
} }
if event.Protocol == nftypes.ICMP { if event.Protocol == nftypes.ICMP {
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{ protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
IcmpInfo: &proto.ICMPInfo{ IcmpInfo: &proto.ICMPInfo{

View File

@@ -3,18 +3,22 @@ package store
import ( import (
"sync" "sync"
"golang.org/x/exp/maps"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
func NewMemoryStore() *Memory { func NewMemoryStore() *Memory {
return &Memory{ return &Memory{
events: make(map[string]*types.Event), events: make(map[uuid.UUID]*types.Event),
} }
} }
type Memory struct { type Memory struct {
mux sync.Mutex mux sync.Mutex
events map[string]*types.Event events map[uuid.UUID]*types.Event
} }
func (m *Memory) StoreEvent(event *types.Event) { func (m *Memory) StoreEvent(event *types.Event) {
@@ -26,7 +30,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() { func (m *Memory) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
m.events = make(map[string]*types.Event) maps.Clear(m.events)
} }
func (m *Memory) GetEvents() []*types.Event { func (m *Memory) GetEvents() []*types.Event {
@@ -39,7 +43,7 @@ func (m *Memory) GetEvents() []*types.Event {
return events return events
} }
func (m *Memory) DeleteEvents(ids []string) { func (m *Memory) DeleteEvents(ids []uuid.UUID) {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
for _, id := range ids { for _, id := range ids {

View File

@@ -2,11 +2,12 @@ package types
import ( import (
"net/netip" "net/netip"
"strconv"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type Protocol uint8 type Protocol uint8
@@ -27,8 +28,10 @@ func (p Protocol) String() string {
return "TCP" return "TCP"
case 17: case 17:
return "UDP" return "UDP"
case 132:
return "SCTP"
default: default:
return "unknown" return strconv.FormatUint(uint64(p), 10)
} }
} }
@@ -61,27 +64,29 @@ const (
) )
type Event struct { type Event struct {
ID string ID uuid.UUID
Timestamp time.Time Timestamp time.Time
EventFields EventFields
} }
type EventFields struct { type EventFields struct {
FlowID uuid.UUID FlowID uuid.UUID
Type Type Type Type
RuleID []byte RuleID []byte
Direction Direction Direction Direction
Protocol Protocol Protocol Protocol
SourceIP netip.Addr SourceIP netip.Addr
DestIP netip.Addr DestIP netip.Addr
SourcePort uint16 SourceResourceID []byte
DestPort uint16 DestResourceID []byte
ICMPType uint8 SourcePort uint16
ICMPCode uint8 DestPort uint16
RxPackets uint64 ICMPType uint8
TxPackets uint64 ICMPCode uint8
RxBytes uint64 RxPackets uint64
TxBytes uint64 TxPackets uint64
RxBytes uint64
TxBytes uint64
} }
type FlowConfig struct { type FlowConfig struct {
@@ -108,7 +113,7 @@ type FlowLogger interface {
// GetEvents returns all stored events // GetEvents returns all stored events
GetEvents() []*Event GetEvents() []*Event
// DeleteEvents deletes events from the store // DeleteEvents deletes events from the store
DeleteEvents([]string) DeleteEvents([]uuid.UUID)
// Close closes the logger // Close closes the logger
Close() Close()
// Enable enables the flow logger receiver // Enable enables the flow logger receiver
@@ -123,7 +128,7 @@ type Store interface {
// GetEvents returns all stored events // GetEvents returns all stored events
GetEvents() []*Event GetEvents() []*Event
// DeleteEvents deletes events from the store // DeleteEvents deletes events from the store
DeleteEvents([]string) DeleteEvents([]uuid.UUID)
// Close closes the store // Close closes the store
Close() Close()
} }
@@ -142,5 +147,5 @@ type ConnTracker interface {
type IFaceMapper interface { type IFaceMapper interface {
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
} }

View File

@@ -16,7 +16,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil { if err != nil {
return fmt.Errorf("failed to open routing socket: %v", err) return fmt.Errorf("failed to open routing socket: %v", err)
@@ -28,18 +28,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
} }
}() }()
go func() {
<-ctx.Done()
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
log.Debugf("Network monitor: closed routing socket: %v", err)
}
}()
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
default: default:
buf := make([]byte, 2048) buf := make([]byte, 2048)
n, err := unix.Read(fd, buf) n, err := unix.Read(fd, buf)
@@ -76,11 +68,11 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
switch msg.Type { switch msg.Type {
case unix.RTM_ADD: case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
go callback() return nil
case unix.RTM_DELETE: case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
go callback() return nil
} }
} }
} }

View File

@@ -14,7 +14,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if nexthopv4.Intf == nil && nexthopv6.Intf == nil { if nexthopv4.Intf == nil && nexthopv6.Intf == nil {
return errors.New("no interfaces available") return errors.New("no interfaces available")
} }
@@ -31,8 +31,7 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
// handle route changes // handle route changes
case route := <-routeChan: case route := <-routeChan:
// default route and main table // default route and main table
@@ -43,12 +42,10 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
// triggered on added/replaced routes // triggered on added/replaced routes
case syscall.RTM_NEWROUTE: case syscall.RTM_NEWROUTE:
log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil return nil
case syscall.RTM_DELROUTE: case syscall.RTM_DELROUTE:
if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) { if nexthopv4.Intf != nil && route.Gw.Equal(nexthopv4.IP.AsSlice()) || nexthopv6.Intf != nil && route.Gw.Equal(nexthopv6.IP.AsSlice()) {
log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex)
go callback()
return nil return nil
} }
} }

View File

@@ -10,7 +10,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) error { func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
routeMonitor, err := systemops.NewRouteMonitor(ctx) routeMonitor, err := systemops.NewRouteMonitor(ctx)
if err != nil { if err != nil {
return fmt.Errorf("failed to create route monitor: %w", err) return fmt.Errorf("failed to create route monitor: %w", err)
@@ -24,20 +24,20 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop, ca
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ErrStopped return ctx.Err()
case route := <-routeMonitor.RouteUpdates(): case route := <-routeMonitor.RouteUpdates():
if route.Destination.Bits() != 0 { if route.Destination.Bits() != 0 {
continue continue
} }
if routeChanged(route, nexthopv4, nexthopv6, callback) { if routeChanged(route, nexthopv4, nexthopv6) {
break return nil
} }
} }
} }
} }
func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop, callback func()) bool { func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Nexthop) bool {
intf := "<nil>" intf := "<nil>"
if route.Interface != nil { if route.Interface != nil {
intf = route.Interface.Name intf = route.Interface.Name
@@ -51,18 +51,15 @@ func routeChanged(route systemops.RouteUpdate, nexthopv4, nexthopv6 systemops.Ne
case systemops.RouteModified: case systemops.RouteModified:
// TODO: get routing table to figure out if our route is affected for modified routes // TODO: get routing table to figure out if our route is affected for modified routes
log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route changed: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
case systemops.RouteAdded: case systemops.RouteAdded:
if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP { if route.NextHop.Is4() && route.NextHop != nexthopv4.IP || route.NextHop.Is6() && route.NextHop != nexthopv6.IP {
log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route added: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
} }
case systemops.RouteDeleted: case systemops.RouteDeleted:
if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP { if nexthopv4.Intf != nil && route.NextHop == nexthopv4.IP || nexthopv6.Intf != nil && route.NextHop == nexthopv6.IP {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf) log.Infof("Network monitor: default route removed: via %s, interface %s", route.NextHop, intf)
go callback()
return true return true
} }
} }

View File

@@ -1,12 +1,27 @@
//go:build !ios && !android
package networkmonitor package networkmonitor
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"net/netip"
"runtime/debug"
"sync" "sync"
"time"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
var ErrStopped = errors.New("monitor has been stopped") const (
debounceTime = 2 * time.Second
)
var checkChangeFn = checkChange
// NetworkMonitor watches for changes in network configuration. // NetworkMonitor watches for changes in network configuration.
type NetworkMonitor struct { type NetworkMonitor struct {
@@ -19,3 +34,99 @@ type NetworkMonitor struct {
func New() *NetworkMonitor { func New() *NetworkMonitor {
return &NetworkMonitor{} return &NetworkMonitor{}
} }
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
nw.mu.Lock()
if nw.cancel != nil {
nw.mu.Unlock()
return errors.New("network monitor already started")
}
ctx, nw.cancel = context.WithCancel(ctx)
defer nw.cancel()
nw.wg.Add(1)
nw.mu.Unlock()
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
// debounce changes
timer := time.NewTimer(0)
timer.Stop()
for {
select {
case <-event:
timer.Reset(debounceTime)
case <-timer.C:
return nil
case <-ctx.Done():
timer.Stop()
return ctx.Err()
}
}
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel == nil {
return
}
nw.cancel()
nw.wg.Wait()
}
func (nw *NetworkMonitor) checkChanges(ctx context.Context, event chan struct{}, nexthop4 systemops.Nexthop, nexthop6 systemops.Nexthop) {
for {
if err := checkChangeFn(ctx, nexthop4, nexthop6); err != nil {
close(event)
return
}
// prevent blocking
select {
case event <- struct{}{}:
default:
}
}
}

View File

@@ -1,82 +0,0 @@
//go:build !ios && !android
package networkmonitor
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime/debug"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns.
func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) {
if ctx.Err() != nil {
return ctx.Err()
}
nw.mu.Lock()
ctx, nw.cancel = context.WithCancel(ctx)
nw.mu.Unlock()
nw.wg.Add(1)
defer nw.wg.Done()
var nexthop4, nexthop6 systemops.Nexthop
operation := func() error {
var errv4, errv6 error
nexthop4, errv4 = systemops.GetNextHop(netip.IPv4Unspecified())
nexthop6, errv6 = systemops.GetNextHop(netip.IPv6Unspecified())
if errv4 != nil && errv6 != nil {
return errors.New("failed to get default next hops")
}
if errv4 == nil {
log.Debugf("Network monitor: IPv4 default route: %s, interface: %s", nexthop4.IP, nexthop4.Intf.Name)
}
if errv6 == nil {
log.Debugf("Network monitor: IPv6 default route: %s, interface: %s", nexthop6.IP, nexthop6.Intf.Name)
}
// continue if either route was found
return nil
}
expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
if err := backoff.Retry(operation, expBackOff); err != nil {
return fmt.Errorf("failed to get default next hops: %w", err)
}
// recover in case sys ops panic
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, debug.Stack())
}
}()
if err := checkChange(ctx, nexthop4, nexthop6, callback); err != nil {
return fmt.Errorf("check change: %w", err)
}
return nil
}
// Stop stops the network monitor.
func (nw *NetworkMonitor) Stop() {
nw.mu.Lock()
defer nw.mu.Unlock()
if nw.cancel != nil {
nw.cancel()
nw.wg.Wait()
}
}

View File

@@ -2,10 +2,21 @@
package networkmonitor package networkmonitor
import "context" import (
"context"
"fmt"
)
func (nw *NetworkMonitor) Start(context.Context, func()) error { type NetworkMonitor struct {
return nil }
// New creates a new network monitor.
func New() *NetworkMonitor {
return &NetworkMonitor{}
}
func (nw *NetworkMonitor) Listen(_ context.Context) error {
return fmt.Errorf("network monitor not supported on mobile platforms")
} }
func (nw *NetworkMonitor) Stop() { func (nw *NetworkMonitor) Stop() {

View File

@@ -0,0 +1,99 @@
package networkmonitor
import (
"context"
"errors"
"testing"
"time"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
type MocMultiEvent struct {
counter int
}
func (m *MocMultiEvent) checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
if m.counter == 0 {
<-ctx.Done()
return ctx.Err()
}
time.Sleep(1 * time.Second)
m.counter--
return nil
}
func TestNetworkMonitor_Close(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
<-ctx.Done()
return ctx.Err()
}
nw := New()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
time.Sleep(1 * time.Second) // wait for the goroutine to start
nw.Stop()
<-done
if !errors.Is(resErr, context.Canceled) {
t.Errorf("unexpected error: %v", resErr)
}
}
func TestNetworkMonitor_Event(t *testing.T) {
checkChangeFn = func(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
timeout, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-timeout.Done():
return nil
}
}
nw := New()
defer nw.Stop()
var resErr error
done := make(chan struct{})
go func() {
resErr = nw.Listen(context.Background())
close(done)
}()
<-done
if !errors.Is(resErr, nil) {
t.Errorf("unexpected error: %v", nil)
}
}
func TestNetworkMonitor_MultiEvent(t *testing.T) {
eventsRepeated := 3
me := &MocMultiEvent{counter: eventsRepeated}
checkChangeFn = me.checkChange
nw := New()
defer nw.Stop()
done := make(chan struct{})
started := time.Now()
go func() {
if resErr := nw.Listen(context.Background()); resErr != nil {
t.Errorf("unexpected error: %v", resErr)
}
close(done)
}()
<-done
expectedResponseTime := time.Duration(eventsRepeated)*time.Second + debounceTime
if time.Since(started) < expectedResponseTime {
t.Errorf("unexpected duration: %v", time.Since(started))
}
}

View File

@@ -442,8 +442,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String()) conn.log.Infof("created new wgProxy for relay connection: %s", wgProxy.EndpointAddr().String())
if conn.iceP2PIsActive() { if conn.isICEActive() {
conn.log.Debugf("do not switch to relay because current priority is: %s", conn.currentConnPriority.String()) conn.log.Infof("do not switch to relay because current priority is: %s", conn.currentConnPriority.String())
conn.setRelayedProxy(wgProxy) conn.setRelayedProxy(wgProxy)
conn.statusRelay.Set(StatusConnected) conn.statusRelay.Set(StatusConnected)
conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey)
@@ -711,8 +711,8 @@ func (conn *Conn) isReadyToUpgrade() bool {
return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay return conn.wgProxyRelay != nil && conn.currentConnPriority != connPriorityRelay
} }
func (conn *Conn) iceP2PIsActive() bool { func (conn *Conn) isICEActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected return (conn.currentConnPriority == connPriorityICEP2P || conn.currentConnPriority == connPriorityICETurn) && conn.statusICE.Get() == StatusConnected
} }
func (conn *Conn) removeWgPeer() error { func (conn *Conn) removeWgPeer() error {

View File

@@ -8,6 +8,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -16,4 +17,5 @@ type WGIface interface {
RemovePeer(peerKey string) error RemovePeer(peerKey string) error
GetStats(peerKey string) (configurer.WGStats, error) GetStats(peerKey string) (configurer.WGStats, error)
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Address() wgaddr.Address
} }

View File

@@ -0,0 +1,100 @@
package peer
import (
"net/netip"
"sync"
log "github.com/sirupsen/logrus"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
type routeIDLookup struct {
localMap sync.Map
remoteMap sync.Map
resolvedIPs sync.Map
}
func (r *routeIDLookup) AddLocalRouteID(resourceID string, route netip.Prefix) {
_, exists := r.localMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in local map", resourceID)
}
}
func (r *routeIDLookup) RemoveLocalRouteID(route netip.Prefix) {
r.localMap.Delete(route)
}
func (r *routeIDLookup) AddRemoteRouteID(resourceID string, route netip.Prefix) {
_, exists := r.remoteMap.LoadOrStore(route, resourceID)
if exists {
log.Tracef("resourceID %s already exists in remote map", resourceID)
}
}
func (r *routeIDLookup) RemoveRemoteRouteID(route netip.Prefix) {
r.remoteMap.Delete(route)
}
func (r *routeIDLookup) AddResolvedIP(resourceID string, route netip.Prefix) {
r.resolvedIPs.Store(route.Addr(), resourceID)
}
func (r *routeIDLookup) RemoveResolvedIP(route netip.Prefix) {
r.resolvedIPs.Delete(route.Addr())
}
func (r *routeIDLookup) Lookup(src, dst netip.Addr, direction nftypes.Direction) (srcResourceID, dstResourceID string) {
// check resolved ip's first
resId, ok := r.resolvedIPs.Load(src)
if ok {
srcResourceID = resId.(string)
} else {
resId, ok := r.resolvedIPs.Load(dst)
if ok {
dstResourceID = resId.(string)
}
}
switch direction {
case nftypes.Ingress:
if srcResourceID == "" || dstResourceID == "" {
r.localMap.Range(func(key, value interface{}) bool {
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
srcResourceID = value.(string)
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
dstResourceID = value.(string)
}
if srcResourceID != "" && dstResourceID != "" {
return false
}
return true
})
}
case nftypes.Egress:
if srcResourceID == "" || dstResourceID == "" {
r.remoteMap.Range(func(key, value interface{}) bool {
if srcResourceID == "" && key.(netip.Prefix).Contains(src) {
srcResourceID = value.(string)
} else if dstResourceID == "" && key.(netip.Prefix).Contains(dst) {
dstResourceID = value.(string)
}
if srcResourceID != "" && dstResourceID != "" {
return false
}
return true
})
}
}
return srcResourceID, dstResourceID
}

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
@@ -176,6 +177,8 @@ type Status struct {
eventQueue *EventQueue eventQueue *EventQueue
ingressGwMgr *ingressgw.Manager ingressGwMgr *ingressgw.Manager
routeIDLookup routeIDLookup
} }
// NewRecorder returns a new Status instance // NewRecorder returns a new Status instance
@@ -311,7 +314,7 @@ func (d *Status) UpdatePeerState(receivedState State) error {
return nil return nil
} }
func (d *Status) AddPeerStateRoute(peer string, route string) error { func (d *Status) AddPeerStateRoute(peer string, route string, resourceId string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -323,6 +326,14 @@ func (d *Status) AddPeerStateRoute(peer string, route string) error {
peerState.AddRoute(route) peerState.AddRoute(route)
d.peers[peer] = peerState d.peers[peer] = peerState
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
} else {
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
}
// todo: consider to make sense of this notification or not // todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
@@ -340,11 +351,28 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
peerState.DeleteRoute(route) peerState.DeleteRoute(route)
d.peers[peer] = peerState d.peers[peer] = peerState
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
} else {
d.routeIDLookup.RemoveRemoteRouteID(pref)
}
// todo: consider to make sense of this notification or not // todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifyPeerListChanged()
return nil return nil
} }
// CheckRoutes checks if the source and destination addresses are within the same route
// and returns the resource ID of the route that contains the addresses
func (d *Status) CheckRoutes(src, dst netip.Addr, direction nftypes.Direction) (srcResId string, dstResId string) {
if d == nil {
return
}
return d.routeIDLookup.Lookup(src, dst, direction)
}
func (d *Status) UpdatePeerICEState(receivedState State) error { func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -558,6 +586,50 @@ func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.notifyAddressChanged() d.notifyAddressChanged()
} }
// AddLocalPeerStateRoute adds a route to the local peer state
func (d *Status) AddLocalPeerStateRoute(route, resourceId string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
}
if d.localPeer.Routes == nil {
d.localPeer.Routes = map[string]struct{}{}
}
d.localPeer.Routes[route] = struct{}{}
d.routeIDLookup.AddLocalRouteID(resourceId, pref)
}
// RemoveLocalPeerStateRoute removes a route from the local peer state
func (d *Status) RemoveLocalPeerStateRoute(route string) {
d.mux.Lock()
defer d.mux.Unlock()
pref, err := netip.ParsePrefix(route)
if err != nil {
log.Errorf("failed to parse prefix %s: %v", route, err)
return
}
delete(d.localPeer.Routes, route)
d.routeIDLookup.RemoveLocalRouteID(pref)
}
// CleanLocalPeerStateRoutes cleans all routes from the local peer state
func (d *Status) CleanLocalPeerStateRoutes() {
d.mux.Lock()
defer d.mux.Unlock()
d.localPeer.Routes = map[string]struct{}{}
}
// CleanLocalPeerState cleans local peer status // CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() { func (d *Status) CleanLocalPeerState() {
d.mux.Lock() d.mux.Lock()
@@ -641,7 +713,7 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) {
d.nsGroupStates = dnsStates d.nsGroupStates = dnsStates
} }
func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) { func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix, resourceId string) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock() defer d.mux.Unlock()
@@ -650,6 +722,10 @@ func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resol
Prefixes: prefixes, Prefixes: prefixes,
ParentDomain: originalDomain, ParentDomain: originalDomain,
} }
for _, prefix := range prefixes {
d.routeIDLookup.AddResolvedIP(resourceId, prefix)
}
} }
func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
@@ -660,6 +736,10 @@ func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) {
for k, v := range d.resolvedDomainsStates { for k, v := range d.resolvedDomainsStates {
if v.ParentDomain == domain { if v.ParentDomain == domain {
delete(d.resolvedDomainsStates, k) delete(d.resolvedDomainsStates, k)
for _, prefix := range v.Prefixes {
d.routeIDLookup.RemoveResolvedIP(prefix)
}
} }
} }
} }

View File

@@ -358,6 +358,12 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool { func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix var routePrefixes []netip.Prefix
for _, routes := range clientRoutes { for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil { if len(routes) > 0 && routes[0] != nil {
@@ -365,14 +371,8 @@ func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool
} }
} }
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
for _, prefix := range routePrefixes { for _, prefix := range routePrefixes {
// default route is // default route is handled by route exclusion / ip rules
if prefix.Bits() == 0 { if prefix.Bits() == 0 {
continue continue
} }

View File

@@ -330,7 +330,7 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem(rsn reason) error
c.connectEvent() c.connectEvent()
} }
err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String()) err := c.statusRecorder.AddPeerStateRoute(c.currentChosen.Peer, c.handler.String(), c.currentChosen.GetResourceID())
if err != nil { if err != nil {
return fmt.Errorf("add peer state route: %w", err) return fmt.Errorf("add peer state route: %w", err)
} }

View File

@@ -160,6 +160,12 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
if r.Extra == nil {
r.SetEdns0(4096, false)
r.MsgHdr.AuthenticatedData = true
}
client := &dns.Client{ client := &dns.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
Net: "udp", Net: "udp",
@@ -315,7 +321,7 @@ func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain dom
if len(toAdd) > 0 || len(toRemove) > 0 { if len(toAdd) > 0 || len(toRemove) > 0 {
d.interceptedDomains[resolvedDomain] = newPrefixes d.interceptedDomains[resolvedDomain] = newPrefixes
originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), "."))
d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes) d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes, d.route.GetResourceID())
if len(toAdd) > 0 { if len(toAdd) > 0 {
log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s",

View File

@@ -288,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e
updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes)
r.dynamicDomains[domain] = updatedPrefixes r.dynamicDomains[domain] = updatedPrefixes
r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes) r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes, r.route.GetResourceID())
} }
return nberrors.FormatErrorOrNil(merr) return nberrors.FormatErrorOrNil(merr)

View File

@@ -3,9 +3,9 @@ package iface
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgIfaceBase interface { type wgIfaceBase interface {
@@ -13,7 +13,7 @@ type wgIfaceBase interface {
RemoveAllowedIP(peerKey string, allowedIP string) error RemoveAllowedIP(peerKey string, allowedIP string) error
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@@ -103,9 +103,7 @@ func (m *serverRouter) removeFromServerNetwork(route *route.Route) error {
delete(m.routes, route.ID) delete(m.routes, route.ID)
state := m.statusRecorder.GetLocalPeerState() m.statusRecorder.RemoveLocalPeerStateRoute(route.Network.String())
delete(state.Routes, route.Network.String())
m.statusRecorder.UpdateLocalPeerState(state)
return nil return nil
} }
@@ -131,18 +129,12 @@ func (m *serverRouter) addToServerNetwork(route *route.Route) error {
m.routes[route.ID] = route m.routes[route.ID] = route
state := m.statusRecorder.GetLocalPeerState()
if state.Routes == nil {
state.Routes = map[string]struct{}{}
}
routeStr := route.Network.String() routeStr := route.Network.String()
if route.IsDynamic() { if route.IsDynamic() {
routeStr = route.Domains.SafeString() routeStr = route.Domains.SafeString()
} }
state.Routes[routeStr] = struct{}{}
m.statusRecorder.UpdateLocalPeerState(state) m.statusRecorder.AddLocalPeerStateRoute(routeStr, route.GetResourceID())
return nil return nil
} }
@@ -164,9 +156,7 @@ func (m *serverRouter) cleanUp() {
} }
state := m.statusRecorder.GetLocalPeerState() m.statusRecorder.CleanLocalPeerStateRoutes()
state.Routes = nil
m.statusRecorder.UpdateLocalPeerState(state)
} }
func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) { func routeToRouterPair(route *route.Route) (firewall.RouterPair, error) {

View File

@@ -71,7 +71,7 @@
</InstallExecuteSequence> </InstallExecuteSequence>
<!-- Icons --> <!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\netbird.ico" /> <Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />
<Property Id="ARPPRODUCTICON" Value="NetbirdIcon" /> <Property Id="ARPPRODUCTICON" Value="NetbirdIcon" />
</Package> </Package>

View File

@@ -5,5 +5,5 @@
#define STRINGIZE(x) #x #define STRINGIZE(x) #x
#define EXPAND(x) STRINGIZE(x) #define EXPAND(x) STRINGIZE(x)
CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml CREATEPROCESS_MANIFEST_RESOURCE_ID RT_MANIFEST manifest.xml
7 ICON ui/netbird.ico 7 ICON ui/assets/netbird.ico
wintun.dll RCDATA wintun.dll wintun.dll RCDATA wintun.dll

View File

@@ -3,7 +3,7 @@ package server
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path"
"syscall" "syscall"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -12,7 +12,6 @@ import (
) )
const ( const (
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284 // STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
stdErrorHandle = ^uintptr(11) stdErrorHandle = ^uintptr(11)
) )
@@ -25,13 +24,10 @@ var (
) )
func handlePanicLog() error { func handlePanicLog() error {
logPath := os.Getenv(windowsPanicLogEnvVar) // TODO: move this to a central location
if logPath == "" { logDir := path.Join(os.Getenv("PROGRAMDATA"), "Netbird")
return nil logPath := path.Join(logDir, "netbird.err")
}
// Ensure the directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0750); err != nil { if err := os.MkdirAll(logDir, 0750); err != nil {
return fmt.Errorf("create panic log directory: %w", err) return fmt.Errorf("create panic log directory: %w", err)
} }
@@ -39,13 +35,11 @@ func handlePanicLog() error {
return fmt.Errorf("enforce permission on panic log file: %w", err) return fmt.Errorf("enforce permission on panic log file: %w", err)
} }
// Open log file with append mode
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil { if err != nil {
return fmt.Errorf("open panic log file: %w", err) return fmt.Errorf("open panic log file: %w", err)
} }
// Redirect stderr to the file
if err = redirectStderr(f); err != nil { if err = redirectStderr(f); err != nil {
if closeErr := f.Close(); closeErr != nil { if closeErr := f.Close(); closeErr != nil {
log.Warnf("failed to close file after redirect error: %v", closeErr) log.Warnf("failed to close file after redirect error: %v", closeErr)
@@ -59,7 +53,6 @@ func handlePanicLog() error {
// redirectStderr redirects stderr to the provided file // redirectStderr redirects stderr to the provided file
func redirectStderr(f *os.File) error { func redirectStderr(f *os.File) error {
// Get the current process's stderr handle
if err := setStdHandle(f); err != nil { if err := setStdHandle(f); err != nil {
return fmt.Errorf("failed to set stderr handle: %w", err) return fmt.Errorf("failed to set stderr handle: %w", err)
} }

View File

@@ -160,7 +160,7 @@ func (s *Server) Start() error {
// mechanism to keep the client connected even when the connection is lost. // mechanism to keep the client connected even when the connection is lost.
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured. // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status, func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Config, statusRecorder *peer.Status,
runningChan chan error, runningChan chan struct{},
) { ) {
backOff := getConnectWithBackoff(ctx) backOff := getConnectWithBackoff(ctx)
retryStarted := false retryStarted := false
@@ -628,20 +628,21 @@ func (s *Server) Up(callerCtx context.Context, _ *proto.UpRequest) (*proto.UpRes
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String()) s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
runningChan := make(chan error) timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan) defer cancel()
runningChan := make(chan struct{}, 1) // buffered channel to do not lose the signal
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, runningChan)
for { for {
select { select {
case err := <-runningChan: case <-runningChan:
if err != nil { return &proto.UpResponse{}, nil
log.Debugf("waiting for engine to become ready failed: %s", err)
} else {
return &proto.UpResponse{}, nil
}
case <-callerCtx.Done(): case <-callerCtx.Done():
log.Debug("context done, stopping the wait for engine to become ready") log.Debug("context done, stopping the wait for engine to become ready")
return nil, callerCtx.Err() return nil, callerCtx.Err()
case <-timeoutCtx.Done():
log.Debug("up is timed out, stopping the wait for engine to become ready")
return nil, timeoutCtx.Err()
} }
} }
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
@@ -41,11 +42,21 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
srcIP = engine.GetWgAddr() srcIP = engine.GetWgAddr()
} }
srcAddr, ok := netip.AddrFromSlice(srcIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
dstIP := net.ParseIP(req.GetDestinationIp()) dstIP := net.ParseIP(req.GetDestinationIp())
if req.GetDestinationIp() == "self" { if req.GetDestinationIp() == "self" {
dstIP = engine.GetWgAddr() dstIP = engine.GetWgAddr()
} }
dstAddr, ok := netip.AddrFromSlice(dstIP)
if !ok {
return nil, fmt.Errorf("invalid source IP address")
}
if srcIP == nil || dstIP == nil { if srcIP == nil || dstIP == nil {
return nil, fmt.Errorf("invalid IP address") return nil, fmt.Errorf("invalid IP address")
} }
@@ -85,8 +96,8 @@ func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (
} }
builder := &uspfilter.PacketBuilder{ builder := &uspfilter.PacketBuilder{
SrcIP: srcIP, SrcIP: srcAddr,
DstIP: dstIP, DstIP: dstAddr,
Protocol: protocol, Protocol: protocol,
SrcPort: uint16(req.GetSourcePort()), SrcPort: uint16(req.GetSourcePort()),
DstPort: uint16(req.GetDestinationPort()), DstPort: uint16(req.GetDestinationPort()),

Some files were not shown because too many files have changed in this diff Show More